import argparse
import os
import os.path as osp
import shutil

import onnx
import torch
from diffusers import AutoencoderKL, FluxTransformer2DModel
from transformers import CLIPTextModel, T5EncoderModel


def get_local_path(local_dir, model_dir):
    model_local_dir = os.path.join(local_dir, model_dir)
    if not os.path.exists(model_local_dir):
        os.makedirs(model_local_dir)
    return model_local_dir


def gather_weights_to_one_file(onnx_path):
    """Re-save an exported model so all external weights live in a single
    model.onnx.data file next to model.onnx."""
    onnx_model = onnx.load(onnx_path)
    onnx_model_without_data = onnx.load(onnx_path, load_external_data=False)
    os.remove(onnx_path)  # remove the old model file
    # Remove the per-tensor external data files left behind by the export.
    dir_path = osp.dirname(onnx_path)
    for ini in onnx_model_without_data.graph.initializer:
        for ed in ini.external_data:
            external_data_path = osp.join(dir_path, ed.value)
            if osp.isfile(external_data_path):
                os.remove(external_data_path)
    for node in onnx_model_without_data.graph.node:
        if node.op_type != "Constant":
            continue
        for attr in node.attribute:
            external_data_path = osp.join(
                dir_path, attr.t.name.replace('/', '_').replace(':', '_'))
            if osp.isfile(external_data_path):
                os.remove(external_data_path)
    onnx.save(onnx_model,
              onnx_path,
              save_as_external_data=True,
              all_tensors_to_one_file=True,
              location="model.onnx.data")


def copy_files(local_dir, save_dir, overwrite=True):
    """Copy the non-exported pipeline pieces (scheduler, tokenizers, configs)
    from local_dir into save_dir."""
    for sub_dir in ["scheduler", "tokenizer", "tokenizer_2"]:
        if overwrite or not osp.exists(osp.join(save_dir, sub_dir)):
            shutil.copytree(osp.join(local_dir, sub_dir),
                            osp.join(save_dir, sub_dir),
                            dirs_exist_ok=True)
    if overwrite or not osp.exists(osp.join(save_dir, 'model_index.json')):
        shutil.copy(osp.join(local_dir, 'model_index.json'),
                    osp.join(save_dir, 'model_index.json'))
    for sub_dir in ['text_encoder', 'text_encoder_2', 'transformer', 'vae']:
        if overwrite or not osp.exists(
                osp.join(save_dir, sub_dir, 'config.json')):
            shutil.copy(osp.join(local_dir, sub_dir, 'config.json'),
                        osp.join(save_dir, sub_dir, 'config.json'))


def export_clip(local_dir,
                model_dir="text_encoder",
                save_dir=None,
                torch_dtype=torch.float32):
    save_dir = save_dir or local_dir
    clip_save_dir = get_local_path(save_dir, model_dir)
    onnx_path = os.path.join(clip_save_dir, "model.onnx")

    bs = 1
    max_len = 77
    sample_inputs = (torch.zeros(bs, max_len, dtype=torch.int32), )
    input_names = ["input_ids"]
    model = CLIPTextModel.from_pretrained(local_dir,
                                          subfolder=model_dir,
                                          torch_dtype=torch_dtype)
    output_names = ["text_embeddings"]
    dynamic_axes = {"input_ids": {0: 'B'}, "text_embeddings": {0: 'B'}}

    # CLIP export requires a nightly PyTorch build due to a bug in the ONNX
    # parser.
    with torch.inference_mode():
        torch.onnx.export(model,
                          sample_inputs,
                          onnx_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)
    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Successfully exported CLIP model: {onnx_path}")
    return onnx_path


def export_t5(local_dir,
              model_dir="text_encoder_2",
              save_dir=None,
              torch_dtype=torch.float32):
    save_dir = save_dir or local_dir
    t5_save_dir = get_local_path(save_dir, model_dir)
    onnx_path = os.path.join(t5_save_dir, "model.onnx")

    bs = 1
    max_len = 512
    sample_inputs = (torch.zeros(bs, max_len, dtype=torch.int32), )
    input_names = ["input_ids"]
    model = T5EncoderModel.from_pretrained(local_dir,
                                           subfolder=model_dir,
                                           torch_dtype=torch_dtype)
    output_names = ["text_embeddings"]
    dynamic_axes = {"input_ids": {0: 'B'}, "text_embeddings": {0: 'B'}}

    with torch.inference_mode():
        torch.onnx.export(model,
                          sample_inputs,
                          onnx_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)
    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Successfully exported T5 model: {onnx_path}")
    return onnx_path
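
# A small, optional sanity check, not called anywhere in this script: it
# assumes onnxruntime and numpy are installed, and simply confirms that an
# exported text-encoder ONNX file loads and runs on a dummy batch (use
# max_len=77 for CLIP, 512 for T5).
def _sanity_check_text_encoder(onnx_path, max_len=77):
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(onnx_path,
                                providers=["CPUExecutionProvider"])
    outputs = sess.run(
        None, {"input_ids": np.zeros((1, max_len), dtype=np.int32)})
    # The first output is "text_embeddings"; expect (1, max_len, hidden_size).
    print(f"{onnx_path}: text_embeddings shape {outputs[0].shape}")
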

# The wrappers below patch the transformer blocks for fp16 inference. Note
# that we do not export fp16 weights directly to ONNX; this lets migraphx
# perform its optimizations first and quantize down to fp16 afterwards,
# which gives better accuracy than exporting fp16 directly to ONNX.
def transformer_block_clip_wrapper(fn):

    def new_forward(*args, **kwargs):
        encoder_hidden_states, hidden_states = fn(*args, **kwargs)
        # 65504 is the largest finite fp16 value; clipping the block outputs
        # keeps activations from overflowing to inf after fp16 quantization.
        return encoder_hidden_states.clip(-65504, 65504), hidden_states

    return new_forward


def single_transformer_block_clip_wrapper(fn):

    def new_forward(*args, **kwargs):
        hidden_states = fn(*args, **kwargs)
        return hidden_states.clip(-65504, 65504)

    return new_forward


def add_output_clippings_for_fp16(model):
    for b in model.transformer_blocks:
        b.forward = transformer_block_clip_wrapper(b.forward)
    for b in model.single_transformer_blocks:
        b.forward = single_transformer_block_clip_wrapper(b.forward)


def export_transformer(local_dir,
                       model_dir="transformer",
                       save_dir=None,
                       torch_dtype=torch.float32,
                       fp16=True):
    save_dir = save_dir or local_dir
    transformer_save_dir = get_local_path(save_dir, model_dir)
    onnx_path = os.path.join(transformer_save_dir, "model.onnx")

    bs = 1
    img_height = 1024
    img_width = 1024
    compression_factor = 8
    latent_h = img_height // compression_factor
    latent_w = img_width // compression_factor
    max_len = 512
    config = FluxTransformer2DModel.load_config(local_dir, subfolder=model_dir)
    # Inputs: packed latents, text embeddings, pooled text embeddings,
    # timestep, image position ids, text position ids, guidance scale.
    sample_inputs = (
        torch.randn(bs, (latent_h // 2) * (latent_w // 2),
                    config["in_channels"],
                    dtype=torch_dtype),
        torch.randn(bs, max_len, config['joint_attention_dim'],
                    dtype=torch_dtype),
        torch.randn(bs, config['pooled_projection_dim'], dtype=torch_dtype),
        torch.tensor([1.] * bs, dtype=torch_dtype),
        torch.randn((latent_h // 2) * (latent_w // 2), 3, dtype=torch_dtype),
        torch.randn(max_len, 3, dtype=torch_dtype),
        torch.tensor([1.] * bs, dtype=torch_dtype),
    )
    input_names = [
        'hidden_states', 'encoder_hidden_states', 'pooled_projections',
        'timestep', 'img_ids', 'txt_ids', 'guidance'
    ]
    model = FluxTransformer2DModel.from_pretrained(local_dir,
                                                   subfolder=model_dir,
                                                   torch_dtype=torch_dtype)
    if fp16:
        print("applying fp16 clip workarounds to transformer")
        add_output_clippings_for_fp16(model)
    output_names = ["latent"]
    dynamic_axes = {
        'hidden_states': {0: 'B', 1: 'latent_dim'},
        'encoder_hidden_states': {0: 'B', 1: 'L'},
        'pooled_projections': {0: 'B'},
        'timestep': {0: 'B'},
        'img_ids': {0: 'latent_dim'},
        'txt_ids': {0: 'L'},
        'guidance': {0: 'B'},
    }
    with torch.inference_mode():
        torch.onnx.export(model,
                          sample_inputs,
                          onnx_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)
    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Successfully exported transformer model: {onnx_path}")
    return onnx_path
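
# Optional: the dynamic axes recorded above can be double-checked on the
# exported graph with the onnx package this script already imports (the
# path below is illustrative):
#
#   m = onnx.load("transformer/model.onnx", load_external_data=False)
#   for inp in m.graph.input:
#       print(inp.name, [d.dim_param or d.dim_value
#                        for d in inp.type.tensor_type.shape.dim])
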

def export_vae(local_dir,
               model_dir="vae",
               save_dir=None,
               torch_dtype=torch.float32):
    save_dir = save_dir or local_dir
    vae_save_dir = get_local_path(save_dir, model_dir)
    onnx_path = os.path.join(vae_save_dir, "model.onnx")

    config = AutoencoderKL.load_config(local_dir, subfolder=model_dir)
    bs = 1
    latent_channels = config['latent_channels']
    img_height = 1024
    img_width = 1024
    compression_factor = 8
    latent_h = img_height // compression_factor
    latent_w = img_width // compression_factor
    sample_inputs = (torch.randn(bs,
                                 latent_channels,
                                 latent_h,
                                 latent_w,
                                 dtype=torch_dtype), )
    input_names = ["latent"]
    model = AutoencoderKL.from_pretrained(local_dir,
                                          subfolder=model_dir,
                                          torch_dtype=torch_dtype)
    # Export only the decoder path: the pipeline feeds latents in and gets
    # images out.
    model.forward = model.decode
    output_names = ["images"]
    dynamic_axes = {
        'latent': {0: 'B', 2: 'H', 3: 'W'},
        'images': {0: 'B', 2: '8H', 3: '8W'},
    }
    with torch.inference_mode():
        torch.onnx.export(model,
                          sample_inputs,
                          onnx_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=output_names,
                          dynamic_axes=dynamic_axes)
    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Successfully exported VAE decoder model: {onnx_path}")
    return onnx_path


def parse_args():
    parser = argparse.ArgumentParser(description="export ONNX models")
    parser.add_argument("--local-dir",
                        type=str,
                        required=True,
                        help="local directory containing the model")
    parser.add_argument("--save-dir",
                        type=str,
                        default=None,
                        help="the directory for saving ONNX models")
    args = parser.parse_args()
    if args.save_dir is None:
        args.save_dir = args.local_dir
    return args


def main():
    args = parse_args()
    local_dir = args.local_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    export_clip(local_dir, save_dir=save_dir)
    export_t5(local_dir, save_dir=save_dir)
    export_transformer(local_dir, save_dir=save_dir)
    export_vae(local_dir, save_dir=save_dir)
    if save_dir != local_dir:
        copy_files(local_dir, save_dir, overwrite=True)


if __name__ == "__main__":
    main()
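
# Example invocation (paths and the script name are illustrative):
#
#   python export_onnx.py --local-dir /path/to/FLUX.1-dev \
#       --save-dir /path/to/flux-onnx
#
# When --save-dir differs from --local-dir, the scheduler, tokenizers,
# model_index.json, and per-model config.json files are copied across so
# that the output directory remains a loadable diffusers layout.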