"""Export an SDXL pipeline's VAE decoder to ONNX and patch its mid-block attention."""
import argparse
import os
import os.path as osp

from diffusers import AutoencoderKL, DiffusionPipeline
from migraphx_diffusers import ONNXModifier
import numpy as np
import onnx
import torch

# Name prefix of the ONNX nodes/tensors emitted for the decoder's mid-block attention.
_MID_ATTN = "/decoder/mid_block/attentions.0"


def export_vae_decoder(pipeline_dir):
    """Export the VAE decoder of a diffusers pipeline to ONNX.

    Loads ``AutoencoderKL`` from the ``vae`` subfolder of *pipeline_dir*,
    routes ``forward`` to ``decode`` so tracing captures only the decoder,
    and writes the graph to ``<pipeline_dir>/vae_decoder/model.onnx``.

    Args:
        pipeline_dir: Path to a diffusers pipeline directory with a ``vae``
            subfolder.

    Returns:
        Path of the written ONNX file.

    Raises:
        RuntimeError: If the ONNX file does not exist after the export.
    """
    save_path = osp.join(pipeline_dir, "vae_decoder", "model.onnx")
    # torch.onnx.export does not create missing parent directories.
    os.makedirs(osp.dirname(save_path), exist_ok=True)

    vae = AutoencoderKL.from_pretrained(pipeline_dir, subfolder="vae")
    # Calling the module now runs decode(), so only the decoder is traced.
    vae.forward = vae.decode

    input_names = ["latent_sample"]
    output_names = ["sample"]
    # Every key must be one of input_names/output_names: torch.onnx.export
    # only warns about unknown keys and leaves that tensor's shape static.
    dynamic_axes = {
        "latent_sample": {0: "batch_size", 2: "latent_height", 3: "latent_width"},
        "sample": {0: "batch_size", 2: "image_height", 3: "image_width"},
    }
    # Tracing input only; batch and spatial dims are declared dynamic above.
    dummy_latent = torch.randn(1, vae.config["latent_channels"], 128, 128)
    torch.onnx.export(
        vae,
        dummy_latent,
        save_path,
        export_params=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
    if not osp.isfile(save_path):
        raise RuntimeError("Failed to export vae decoder to ONNX.")
    print(f"Successfully exported vae decoder to ONNX: {save_path}")
    return save_path


def modify_vae_decoder(onnx_path, save_path=None):
    """Rewire the mid-block attention subgraph of an exported VAE decoder.

    The edits, all on nodes/tensors under ``/decoder/mid_block/attentions.0``:

    * ``Gather_1``/``Gather_2``/``Gather_3`` and ``Reshape_5`` read the
      output of the ``Shape`` node.
    * The ``to_k`` MatMul has its operands swapped (weight first, the
      group_norm ``Add`` output second), its weight initializer transposed,
      and its bias reshaped from ``[C]`` to ``[C, 1]``.
    * ``Mul_1`` / ``Mul_2`` multiply the ``to_q`` / ``to_k`` outputs by one
      shared scalar ``C ** -0.25``, so their product carries ``1/sqrt(C)``.
      This equals the attention scale only for a single head whose width
      equals C (512 for the SDXL VAE) — assumption, TODO confirm for other
      VAEs.
    * ``MatMul_1`` takes the ``to_v`` output as its second input, and the
      ``to_out.0`` MatMul consumes ``MatMul_1``'s output directly.

    Node names are those produced by ``torch.onnx.export`` for the diffusers
    VAE; a different exporter or model may name them differently.

    Args:
        onnx_path: Path of the ONNX model to modify.
        save_path: Where to write the result; defaults to overwriting
            *onnx_path*.
    """
    om = ONNXModifier(onnx_path)

    shape_node = om.get_node(f"{_MID_ATTN}/Shape")
    shape_out = shape_node.outputs[0]
    for gather in ("Gather_1", "Gather_2", "Gather_3"):
        om.get_node(f"{_MID_ATTN}/{gather}").set_input(0, shape_out)

    # to_k projection: put the weight on the left so the MatMul consumes
    # the group_norm output directly (presumably channel-first).
    k_matmul = om.get_node(f"{_MID_ATTN}/to_k/MatMul")
    # The weight initializer's name is auto-generated by the exporter
    # (e.g. "onnx::MatMul_918"), so resolve it through the node.
    k_weight_name = k_matmul.inputs[1]
    k_matmul.set_input(0, k_weight_name)
    k_matmul.set_input(1, f"{_MID_ATTN}/group_norm/Add_output_0")

    # The weight moved to the left operand, so store it transposed.
    k_weight = om.get_initializer(k_weight_name)
    weight_dtype = onnx.helper.tensor_dtype_to_np_dtype(k_weight.data_type)
    weight_shape = list(k_weight.dims)
    transposed = (
        np.frombuffer(k_weight.raw_data, dtype=weight_dtype)
        .reshape(weight_shape)
        .T
    )
    # tobytes() serialises in C order, i.e. the transposed layout.
    k_weight.raw_data = transposed.tobytes()
    k_weight.dims[:] = weight_shape[::-1]

    # Column-vector bias so it broadcasts along the last axis of the new output.
    k_bias = om.get_initializer("decoder.mid_block.attentions.0.to_k.bias")
    channels = int(k_bias.dims[0])
    k_bias.dims[:] = [channels, 1]

    q_mul = om.get_node(f"{_MID_ATTN}/Mul_1")
    k_mul = om.get_node(f"{_MID_ATTN}/Mul_2")
    q_mul.set_input(0, f"{_MID_ATTN}/to_q/Add_output_0")
    k_mul.set_input(0, f"{_MID_ATTN}/to_k/Add_output_0")
    # C ** -0.25 on both q and k gives the 1/sqrt(C) scale on q·k; the
    # constant matches the weights' dtype so Mul sees uniform inputs.
    scale = om.create_initializer(
        f"{_MID_ATTN}/mul_B",
        np.sqrt(1 / np.sqrt(channels)).astype(weight_dtype)[None],
    )
    q_mul.set_input(1, scale.name)
    k_mul.set_input(1, scale.name)

    attn_matmul = om.get_node(f"{_MID_ATTN}/MatMul_1")
    attn_matmul.set_input(1, f"{_MID_ATTN}/to_v/Add_output_0")
    out_matmul = om.get_node(f"{_MID_ATTN}/to_out.0/MatMul")
    out_matmul.set_input(0, attn_matmul.outputs[0])

    om.get_node(f"{_MID_ATTN}/Reshape_5").set_input(1, shape_out)

    # Presumably refreshes ONNXModifier's internal lookup tables and drops
    # nodes/initializers left unused by the rewiring — verify in its source.
    om.update_map()
    om.remove_trash()
    om.save(onnx_path if save_path is None else save_path)


def main():
    """Command-line entry point: export the VAE decoder, then patch it in place."""
    # The first positional ArgumentParser argument is `prog`, so pass the
    # sentence as `description` to keep the real program name in usage text.
    parser = argparse.ArgumentParser(description="Export vae decoder to ONNX")
    parser.add_argument(
        "--pipeline-dir",
        type=str,
        required=True,
        help="The path to the sdxl pipeline directory.",
    )
    args = parser.parse_args()
    onnx_path = export_vae_decoder(args.pipeline_dir)
    modify_vae_decoder(onnx_path)


if __name__ == "__main__":
    main()