export_onnx.py 11.7 KB
Newer Older
wangwf's avatar
wangwf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import argparse
import os
import os.path as osp
import shutil

import onnx
import torch
from transformers import (CLIPTextModel, T5EncoderModel)
from diffusers import FluxTransformer2DModel, AutoencoderKL


def get_local_path(local_dir, model_dir):
    """Return ``local_dir/model_dir``, creating that directory when it is missing.

    Args:
        local_dir: Parent directory.
        model_dir: Sub-directory name appended to ``local_dir``.

    Returns:
        The joined path (whether it was just created or already existed).
    """
    target = os.path.join(local_dir, model_dir)
    if os.path.exists(target):
        return target
    os.makedirs(target)
    return target


def gather_weights_to_one_file(onnx_path, location="model.onnx.data"):
    """Re-save an ONNX model so that all of its weights live in one data file.

    The model at ``onnx_path`` is loaded together with its external data.
    The old ``.onnx`` file and the external data files it references are then
    deleted. Finally the model is written back to ``onnx_path`` with every
    tensor stored in the single external file ``location``, which is resolved
    relative to the model's directory.

    Args:
        onnx_path: Path of the exported ``.onnx`` file; it is rewritten in place.
        location: File name (relative to the model directory) of the
            consolidated weight file. Defaults to ``"model.onnx.data"``.
    """
    # Load once with the weights (this is what gets saved back) and once
    # without them, so the latter still carries the external-data references
    # that tell us which files to delete.
    onnx_model = onnx.load(onnx_path)
    onnx_model_without_data = onnx.load(onnx_path, load_external_data=False)

    os.remove(onnx_path)  # remove old model file

    dir_path = osp.dirname(onnx_path)

    # Remove the data files referenced by initializers. Only the "location"
    # entry names a file; the other entries ("offset", "length", ...) are
    # numbers and must never be interpreted as paths.
    for ini in onnx_model_without_data.graph.initializer:
        for ed in ini.external_data:
            if ed.key != "location":
                continue
            external_data_path = osp.join(dir_path, ed.value)
            if osp.isfile(external_data_path):
                os.remove(external_data_path)

    # Remove files named after Constant-node tensors, with '/' and ':'
    # replaced by '_'.
    # NOTE(review): presumably this mirrors how the exporter names external
    # files for large Constant tensors — confirm.
    for node in onnx_model_without_data.graph.node:
        if node.op_type != "Constant":
            continue
        for attr in node.attribute:
            if not attr.t.name:  # attribute carries no named tensor
                continue
            external_data_path = osp.join(
                dir_path, attr.t.name.replace('/', '_').replace(':', '_'))
            if osp.isfile(external_data_path):
                os.remove(external_data_path)

    # onnx.save appends to an already existing external data file rather than
    # truncating it, so a leftover file from a previous run would keep its
    # stale bytes (and grow on every run). The weights are already in memory,
    # so deleting it here is safe.
    target_data_path = osp.join(dir_path, location)
    if osp.isfile(target_data_path):
        os.remove(target_data_path)

    onnx.save(onnx_model,
              onnx_path,
              save_as_external_data=True,
              all_tensors_to_one_file=True,
              location=location)


def copy_files(local_dir, save_dir, overwrite=True):
    """Copy the non-ONNX pipeline assets from ``local_dir`` to ``save_dir``.

    Copies the ``scheduler``, ``tokenizer`` and ``tokenizer_2`` directories,
    ``model_index.json``, and the ``config.json`` of each exported component
    (``text_encoder``, ``text_encoder_2``, ``transformer``, ``vae``).

    Args:
        local_dir: Source pipeline directory.
        save_dir: Destination directory.
        overwrite: If True, always copy (existing files are replaced). If
            False, skip any asset whose destination path already exists.
    """
    # Whole asset directories.
    for sub_dir in ("scheduler", "tokenizer", "tokenizer_2"):
        dst = osp.join(save_dir, sub_dir)
        if overwrite or not osp.exists(dst):
            shutil.copytree(osp.join(local_dir, sub_dir),
                            dst,
                            dirs_exist_ok=True)

    index_dst = osp.join(save_dir, 'model_index.json')
    if overwrite or not osp.exists(index_dst):
        shutil.copy(osp.join(local_dir, 'model_index.json'), index_dst)

    # Per-component configs. Create the component folder first so this also
    # works when the ONNX export for that component has not created it.
    for sub_dir in ('text_encoder', 'text_encoder_2', 'transformer', 'vae'):
        dst = osp.join(save_dir, sub_dir, 'config.json')
        if overwrite or not osp.exists(dst):
            os.makedirs(osp.join(save_dir, sub_dir), exist_ok=True)
            shutil.copy(osp.join(local_dir, sub_dir, 'config.json'), dst)


def export_clip(local_dir, 
                model_dir="text_encoder", 
                save_dir=None,
                torch_dtype=torch.float32):
    """Export the CLIP text encoder to ``<save_dir>/<model_dir>/model.onnx``.

    The encoder is loaded with ``CLIPTextModel.from_pretrained`` from
    ``local_dir`` (sub-folder ``model_dir``) and traced with an all-zero
    ``int32`` token batch of shape ``(1, 77)``. Only the batch axis is marked
    dynamic. The exported file is then passed to
    ``gather_weights_to_one_file``.

    Args:
        local_dir: Directory holding the pretrained pipeline.
        model_dir: Sub-folder of the text encoder; reused under the output root.
        save_dir: Output root directory; falls back to ``local_dir``.
        torch_dtype: dtype the weights are loaded in.

    Returns:
        Path of the written ONNX file.
    """
    out_dir = get_local_path(save_dir or local_dir, model_dir)
    onnx_path = os.path.join(out_dir, "model.onnx")

    batch_size, seq_len = 1, 77
    dummy_ids = torch.zeros(batch_size, seq_len, dtype=torch.int32)

    text_encoder = CLIPTextModel.from_pretrained(local_dir,
                                                 subfolder=model_dir,
                                                 torch_dtype=torch_dtype)

    # Original author's note: CLIP export requires nightly pytorch due to a
    # bug in the onnx parser.
    with torch.inference_mode():
        torch.onnx.export(
            text_encoder,
            (dummy_ids, ),
            onnx_path,
            export_params=True,
            input_names=["input_ids"],
            output_names=["text_embeddings"],
            dynamic_axes={
                "input_ids": {0: 'B'},
                "text_embeddings": {0: 'B'},
            },
        )

    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Success export clip model: {onnx_path}")
    return onnx_path


def export_t5(local_dir, 
              model_dir="text_encoder_2", 
              save_dir=None,
              torch_dtype=torch.float32):
    """Export the T5 text encoder to ``<save_dir>/<model_dir>/model.onnx``.

    The encoder is loaded with ``T5EncoderModel.from_pretrained`` and traced
    with an all-zero ``int32`` token batch of shape ``(1, 512)``. Only the
    batch axis is dynamic. The exported file is then passed to
    ``gather_weights_to_one_file``.

    Args:
        local_dir: Directory holding the pretrained pipeline.
        model_dir: Sub-folder of the T5 encoder; reused under the output root.
        save_dir: Output root directory; falls back to ``local_dir``.
        torch_dtype: dtype the weights are loaded in.

    Returns:
        Path of the written ONNX file.
    """
    out_dir = get_local_path(save_dir or local_dir, model_dir)
    onnx_path = os.path.join(out_dir, "model.onnx")

    batch_size, seq_len = 1, 512
    dummy_ids = torch.zeros(batch_size, seq_len, dtype=torch.int32)

    encoder = T5EncoderModel.from_pretrained(local_dir,
                                             subfolder=model_dir,
                                             torch_dtype=torch_dtype)

    with torch.inference_mode():
        torch.onnx.export(
            encoder,
            (dummy_ids, ),
            onnx_path,
            export_params=True,
            input_names=["input_ids"],
            output_names=["text_embeddings"],
            dynamic_axes={
                "input_ids": {0: 'B'},
                "text_embeddings": {0: 'B'},
            },
        )

    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Success export t5 model: {onnx_path}")
    return onnx_path

# The following wrappers apply the fp16 inference patch to the transformer
# blocks by clipping block outputs to the fp16 range [-65504, 65504]. Note that
# we do not export fp16 weights directly to ONNX, so that MIGraphX can perform
# its optimizations before quantizing down to fp16; this gives better accuracy
# than exporting fp16 weights directly to ONNX.
def transformer_block_clip_wrapper(fn):
    """Wrap a dual-stream block ``forward`` so its first output is fp16-safe.

    ``fn`` must return a pair ``(encoder_hidden_states, hidden_states)``. The
    returned callable forwards all arguments to ``fn`` and clips only the
    first element of the pair to [-65504, 65504]; the second is passed through
    unchanged.
    """
    fp16_max = 65504

    def new_forward(*args, **kwargs):
        enc_out, img_out = fn(*args, **kwargs)
        return enc_out.clip(-fp16_max, fp16_max), img_out

    return new_forward


def single_transformer_block_clip_wrapper(fn):
    """Wrap a single-stream block ``forward`` so its output is fp16-safe.

    The returned callable forwards all arguments to ``fn`` and clips the
    single returned tensor to [-65504, 65504].
    """
    fp16_max = 65504

    def new_forward(*args, **kwargs):
        return fn(*args, **kwargs).clip(-fp16_max, fp16_max)

    return new_forward


def add_output_clippings_for_fp16(model):
    """Patch every block of ``model`` so its outputs are clipped to the fp16 range.

    Mutates ``model`` in place and returns ``None``:
    - each entry of ``model.transformer_blocks`` has its ``forward`` wrapped
      with ``transformer_block_clip_wrapper``;
    - each entry of ``model.single_transformer_blocks`` has its ``forward``
      wrapped with ``single_transformer_block_clip_wrapper``.
    """
    block_groups = (
        (model.transformer_blocks, transformer_block_clip_wrapper),
        (model.single_transformer_blocks,
         single_transformer_block_clip_wrapper),
    )
    for blocks, wrap in block_groups:
        for block in blocks:
            block.forward = wrap(block.forward)


def export_transformer(local_dir,
                       model_dir="transformer",
                       save_dir=None,
                       torch_dtype=torch.float32,
                       fp16=True):
    """Export the Flux transformer to ``<save_dir>/<model_dir>/model.onnx``.

    Dummy inputs are sized for a 1024x1024 image with a compression factor of
    8 and a text length of 512. Channel sizes come from the model config
    (``in_channels``, ``joint_attention_dim``, ``pooled_projection_dim``).
    The batch, latent-token and text-length axes are marked dynamic.

    Args:
        local_dir: Directory holding the pretrained pipeline.
        model_dir: Sub-folder of the transformer; reused under the output root.
        save_dir: Output root directory; falls back to ``local_dir``.
        torch_dtype: dtype of the loaded weights and of the float dummy inputs.
        fp16: If True, wrap the block ``forward`` methods with
            ``add_output_clippings_for_fp16`` before export. The weights
            themselves stay in ``torch_dtype``.

    Returns:
        Path of the written ONNX file.
    """
    out_dir = get_local_path(save_dir or local_dir, model_dir)
    onnx_path = os.path.join(out_dir, "model.onnx")

    batch_size = 1
    image_height, image_width = 1024, 1024
    downscale = 8
    text_len = 512
    # Latent tokens: the downscaled latent grid halved along each side
    # (presumably Flux's 2x2 latent packing — confirm).
    num_img_tokens = ((image_height // downscale) // 2) * \
        ((image_width // downscale) // 2)

    cfg = FluxTransformer2DModel.load_config(local_dir, subfolder=model_dir)

    # Built in the same order as forward()'s positional parameters.
    hidden_states = torch.randn(batch_size, num_img_tokens,
                                cfg["in_channels"], dtype=torch_dtype)
    encoder_hidden_states = torch.randn(batch_size, text_len,
                                        cfg['joint_attention_dim'],
                                        dtype=torch_dtype)
    pooled_projections = torch.randn(batch_size,
                                     cfg['pooled_projection_dim'],
                                     dtype=torch_dtype)
    timestep = torch.tensor([1.] * batch_size, dtype=torch_dtype)
    img_ids = torch.randn(num_img_tokens, 3, dtype=torch_dtype)
    txt_ids = torch.randn(text_len, 3, dtype=torch_dtype)
    guidance = torch.tensor([1.] * batch_size, dtype=torch_dtype)
    dummy_inputs = (hidden_states, encoder_hidden_states, pooled_projections,
                    timestep, img_ids, txt_ids, guidance)

    input_names = [
        'hidden_states', 'encoder_hidden_states', 'pooled_projections',
        'timestep', 'img_ids', 'txt_ids', 'guidance'
    ]

    transformer = FluxTransformer2DModel.from_pretrained(
        local_dir, subfolder=model_dir, torch_dtype=torch_dtype)

    if fp16:
        print("applying fp16 clip workarounds to transformer")
        add_output_clippings_for_fp16(transformer)

    # 'latent_dim' ties hidden_states/img_ids; 'L' ties the text tensors.
    dynamic_axes = {
        'hidden_states': {0: 'B', 1: 'latent_dim'},
        'encoder_hidden_states': {0: 'B', 1: 'L'},
        'pooled_projections': {0: 'B'},
        'timestep': {0: 'B'},
        'img_ids': {0: 'latent_dim'},
        'txt_ids': {0: 'L'},
        'guidance': {0: 'B'},
    }

    with torch.inference_mode():
        torch.onnx.export(transformer,
                          dummy_inputs,
                          onnx_path,
                          export_params=True,
                          input_names=input_names,
                          output_names=["latent"],
                          dynamic_axes=dynamic_axes)

    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Success export transformer model: {onnx_path}")
    return onnx_path


def export_vae(local_dir, 
               model_dir="vae", 
               save_dir=None,
               torch_dtype=torch.float32):
    """Export the VAE decoder to ``<save_dir>/<model_dir>/model.onnx``.

    ``model.forward`` is replaced by ``model.decode``, so the exported graph
    takes the ``latent`` input and produces ``images``. The trace uses a
    random latent of shape ``(1, latent_channels, 1024 // 8, 1024 // 8)``.
    Batch and spatial axes are marked dynamic.

    Args:
        local_dir: Directory holding the pretrained pipeline.
        model_dir: Sub-folder of the VAE; reused under the output root.
        save_dir: Output root directory; falls back to ``local_dir``.
        torch_dtype: dtype of the loaded weights and of the dummy latent.

    Returns:
        Path of the written ONNX file.
    """
    out_dir = get_local_path(save_dir or local_dir, model_dir)
    onnx_path = os.path.join(out_dir, "model.onnx")

    cfg = AutoencoderKL.load_config(local_dir, subfolder=model_dir)
    batch_size = 1
    image_height, image_width = 1024, 1024
    downscale = 8
    dummy_latent = torch.randn(batch_size,
                               cfg['latent_channels'],
                               image_height // downscale,
                               image_width // downscale,
                               dtype=torch_dtype)

    vae = AutoencoderKL.from_pretrained(local_dir,
                                        subfolder=model_dir,
                                        torch_dtype=torch_dtype)
    # Trace only the decode path.
    vae.forward = vae.decode

    with torch.inference_mode():
        torch.onnx.export(
            vae,
            (dummy_latent, ),
            onnx_path,
            export_params=True,
            input_names=["latent"],
            output_names=["images"],
            dynamic_axes={
                'latent': {0: 'B', 2: 'H', 3: 'W'},
                'images': {0: 'B', 2: '8H', 3: '8W'},
            },
        )

    assert os.path.isfile(onnx_path)
    gather_weights_to_one_file(onnx_path)
    print(f"Success export vae_decoder model: {onnx_path}")
    return onnx_path


def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, matching the original
            behaviour.

    Returns:
        ``argparse.Namespace`` with ``local_dir`` and ``save_dir``. When
        ``--save-dir`` is not given, ``save_dir`` falls back to ``local_dir``.
    """
    parser = argparse.ArgumentParser(description="export ONNX models")
    parser.add_argument("--local-dir",
                        type=str,
                        required=True,
                        help="local directory containing the model")
    # Optional flag: ``required`` expects a bool, so express "unset" via the
    # default instead of ``required=None``.
    parser.add_argument("--save-dir",
                        type=str,
                        default=None,
                        help="the directory for saving ONNX models")
    args = parser.parse_args(argv)

    if args.save_dir is None:
        args.save_dir = args.local_dir

    return args


def main():
    """Export all pipeline components to ONNX and copy the remaining assets.

    The CLIP encoder, T5 encoder, transformer and VAE decoder are exported
    into ``--save-dir`` (default: ``--local-dir``). When the output directory
    is a different directory from the source, ``copy_files`` also copies the
    scheduler/tokenizer folders, ``model_index.json`` and the component
    configs.
    """
    args = parse_args()
    local_dir = args.local_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    export_clip(local_dir, save_dir=save_dir)
    export_t5(local_dir, save_dir=save_dir)
    export_transformer(local_dir, save_dir=save_dir)
    export_vae(local_dir, save_dir=save_dir)

    # Compare resolved paths rather than raw strings. "model", "model/" and
    # "./model" name the same directory, and copying a directory onto itself
    # makes shutil raise (SameFileError / shutil.Error) after all the
    # expensive exports have already run.
    if osp.realpath(save_dir) != osp.realpath(local_dir):
        copy_files(local_dir, save_dir, overwrite=True)


if __name__ == "__main__":
    # Run the export pipeline only when executed as a script, not on import.
    main()