"""Export Stable Diffusion pipeline submodules (text encoder, UNet, VAE decoder) to ONNX."""

import argparse
import os
import os.path as osp
import shutil

import onnx
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel


def export_text_encoder(pipeline_dir):
    model_name = "text_encoder"
    save_path = osp.join(pipeline_dir, model_name, "model.onnx")

    model = CLIPTextModel.from_pretrained(osp.join(pipeline_dir, model_name))

    input_names = ["input_ids"]
    output_names = ["last_hidden_state", "pooler_output"]
    dynamic_axes = {
        'input_ids': {
            0: 'batch_size',
            1: 'sequence_length',
        },
        'last_hidden_state': {
            0: 'batch_size',
            1: 'sequence_length',
        },
        'pooler_output': {
            0: 'batch_size',
        }
    }
    # Dummy input: a single sequence of token ids at the maximum sequence length.
    torch.onnx.export(
        model,
        (torch.zeros(1, model.config.max_position_embeddings, dtype=torch.int32), ),
        save_path,
        export_params=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes
    )

    if osp.isfile(save_path):
        print(f"Successfully exported {model_name} to ONNX: {save_path}")
    else:
        raise RuntimeError(f"Failed to export {model_name} to ONNX.")

    return save_path


def export_unet(pipeline_dir):
    model_name = "unet"
    save_path = osp.join(pipeline_dir, model_name, "model.onnx")
    tmp_dir = "./temp"
    os.makedirs(tmp_dir, exist_ok=True)
    tmp_path = osp.join(tmp_dir, "model.onnx")

    model = UNet2DConditionModel.from_pretrained(pipeline_dir, subfolder=model_name)

    input_names = ["sample", "timestep", "encoder_hidden_states"]
    output_names = ["out_sample"]
    dynamic_axes = {
        'sample': {
            0: 'batch_size',
            1: 'num_channels',
            2: 'height',
            3: 'width'
        },
        'timestep': {
            0: 'steps',
        },
        'encoder_hidden_states': {
            0: 'batch_size',
            1: 'sequence_length',
        },
        'out_sample': {
            0: 'batch_size',
            1: 'num_channels',
            2: 'height',
            3: 'width'
        }
    }
    dummy_input = (
        torch.randn(2, model.config["in_channels"], 64, 64),
        torch.tensor([1], dtype=torch.int64),
        torch.randn(2, 77, 1024)
    )
    # Export to a temporary location first, then re-save with the weights as
    # external data so the model is not limited by the 2 GB protobuf size cap.
    torch.onnx.export(
        model,
        dummy_input,
        tmp_path,
        export_params=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes
    )

    onnx_model = onnx.load(tmp_path)
    # The external-data location is stored relative to the model file, but the
    # existence check must use the full path next to save_path.
    external_data_name = osp.basename(save_path) + '.data'
    external_data_path = osp.join(osp.dirname(save_path), external_data_name)
    if osp.isfile(external_data_path):
        os.remove(external_data_path)
    onnx.save(onnx_model,
              save_path,
              save_as_external_data=True,
              all_tensors_to_one_file=True,
              location=external_data_name,
              size_threshold=1024,
              convert_attribute=False)
    shutil.rmtree(tmp_dir)

    if osp.isfile(save_path):
        print(f"Successfully exported {model_name} to ONNX: {save_path}")
    else:
        raise RuntimeError(f"Failed to export {model_name} to ONNX.")

    return save_path


def export_vae_decoder(pipeline_dir):
    model_name = "vae_decoder"
    sub_model_dir = osp.join(pipeline_dir, model_name)
    os.makedirs(sub_model_dir, exist_ok=True)
    shutil.copy(osp.join(pipeline_dir, 'vae/config.json'),
                osp.join(sub_model_dir, "config.json"))
    save_path = osp.join(sub_model_dir, "model.onnx")

    vae = AutoencoderKL.from_pretrained(pipeline_dir, subfolder="vae")

    input_names = ["latent_sample"]
    output_names = ["sample"]
    dynamic_axes = {
        'latent_sample': {
            0: 'batch_size',
            2: 'latent_height',
            3: 'latent_width'
        },
        'sample': {
            0: 'batch_size',
            2: 'image_height',
            3: 'image_width'
        }
    }
    # Export only the decoder half of the autoencoder.
    vae.forward = vae.decode
    torch.onnx.export(
        vae,
        (torch.randn(1, vae.config["latent_channels"], 64, 64), ),
        save_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes
    )

    if osp.isfile(save_path):
        print(f"Successfully exported {model_name} to ONNX: {save_path}")
    else:
        raise RuntimeError(f"Failed to export {model_name} to ONNX.")

    return save_path


def main():
    parser = argparse.ArgumentParser(
        description="Export the text encoder, UNet, and VAE decoder to ONNX")
    parser.add_argument("--pipeline-dir", type=str, required=True,
                        help="The path to the sdxl pipeline directory.")
    args = parser.parse_args()

    export_text_encoder(args.pipeline_dir)
    export_unet(args.pipeline_dir)
    export_vae_decoder(args.pipeline_dir)


if __name__ == "__main__":
    main()