# convert_stable_diffusion_checkpoint_to_onnx.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
16
17
import os
import shutil
18
19
20
21
22
from pathlib import Path

import torch
from torch.onnx import export

23
import onnx
24
from diffusers import OnnxStableDiffusionPipeline, StableDiffusionPipeline
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version


# True when the installed torch release predates 1.11. Only the base version is compared, so
# local/pre-release suffixes such as "+cu113" do not affect the result.
is_torch_less_than_1_11 = version.parse("1.11") > version.parse(version.parse(torch.__version__).base_version)


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    """Trace `model` with `model_args` and write the resulting ONNX graph to `output_path`.

    Args:
        model: The module to export.
        model_args: Positional inputs used to trace `model`.
        output_path: Destination `.onnx` file; its parent directory is created if missing.
        ordered_input_names: Names given to the graph inputs, in positional order.
        output_names: Names given to the graph outputs.
        dynamic_axes: Mapping of input/output name to `{axis_index: axis_name}` for axes whose
            size may vary at inference time.
        opset: ONNX operator-set version passed to the exporter.
        use_external_data_format: Store weights outside the `.onnx` protobuf. Only forwarded on
            torch < 1.11, where the argument still exists.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    export_kwargs = dict(
        f=output_path.as_posix(),
        input_names=ordered_input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
        opset_version=opset,
    )
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so they are only passed on older versions for backwards compatibility
    if is_torch_less_than_1_11:
        export_kwargs.update(
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
        )
    export(model, model_args, **export_kwargs)


@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
    """Export the sub-models of a PyTorch Stable Diffusion pipeline to ONNX and save an ONNX pipeline.

    The text encoder, UNet, VAE encoder, VAE decoder and (when present) safety checker of the
    `StableDiffusionPipeline` at `model_path` are traced with dummy inputs and written to
    `<output_path>/<component>/model.onnx`. An `OnnxStableDiffusionPipeline` built from those files
    is saved to `output_path` and then reloaded once on the CPU execution provider as a smoke test.

    Args:
        model_path: Local directory or Hub id of the `diffusers` checkpoint to convert.
        output_path: Directory the ONNX pipeline is written to.
        opset: ONNX operator-set version passed to the exporter.
        fp16: Export in float16; requires CUDA.

    Raises:
        ValueError: If `fp16` is requested but CUDA is not available.
    """
    dtype = torch.float16 if fp16 else torch.float32
    if fp16 and not torch.cuda.is_available():
        raise ValueError("`float16` model export is only supported on GPUs with CUDA")
    device = "cuda" if fp16 else "cpu"
    pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    output_path = Path(output_path)

    # TEXT ENCODER
    num_tokens = pipeline.tokenizer.model_max_length
    # read before the text encoder is released: the UNet's dummy `encoder_hidden_states` must have
    # the same width as the text encoder output instead of a hard-coded 768
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=num_tokens,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(device=device, dtype=torch.int32),),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )
    del pipeline.text_encoder

    # UNET
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
            torch.LongTensor([0, 1]).to(device=device),
            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
            False,
        ),
        output_path=unet_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
    unet_model_path = unet_path.absolute().as_posix()
    unet_dir = os.path.dirname(unet_model_path)
    unet = onnx.load(unet_model_path)
    # clean up existing tensor files
    shutil.rmtree(unet_dir)
    os.mkdir(unet_dir)
    # collate external tensor files into one
    onnx.save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )
    # the loaded proto holds every UNet weight in memory; release it before the next export
    del unet
    del pipeline.unet

    # VAE ENCODER
    vae_encoder = pipeline.vae
    # need to get the raw tensor output (sample) from the encoder
    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
    onnx_export(
        vae_encoder,
        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
        output_path=output_path / "vae_encoder" / "model.onnx",
        ordered_input_names=["sample", "return_dict"],
        output_names=["latent_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )

    # VAE DECODER
    # same module as `vae_encoder`; forward is re-pointed so only the decoder part is traced
    vae_decoder = pipeline.vae
    vae_decoder.forward = vae_decoder.decode
    onnx_export(
        vae_decoder,
        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
        output_path=output_path / "vae_decoder" / "model.onnx",
        ordered_input_names=["latent_sample", "return_dict"],
        output_names=["sample"],
        dynamic_axes={
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
        opset=opset,
    )
    del pipeline.vae

    # SAFETY CHECKER
    # `pipeline.safety_checker` may be None (checkpoint loaded without one); there is then nothing
    # to export. NOTE(review): confirm this diffusers version's OnnxStableDiffusionPipeline accepts None.
    if pipeline.safety_checker is not None:
        safety_checker = pipeline.safety_checker
        safety_checker.forward = safety_checker.forward_onnx
        onnx_export(
            safety_checker,
            model_args=(
                torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
                torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
            ),
            output_path=output_path / "safety_checker" / "model.onnx",
            ordered_input_names=["clip_input", "images"],
            output_names=["out_images", "has_nsfw_concepts"],
            dynamic_axes={
                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
            },
            opset=opset,
        )
        del pipeline.safety_checker
        onnx_safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
        feature_extractor = pipeline.feature_extractor
    else:
        onnx_safety_checker = None
        feature_extractor = None

    onnx_pipeline = OnnxStableDiffusionPipeline(
        vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
        vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
        text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
        tokenizer=pipeline.tokenizer,
        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
        scheduler=pipeline.scheduler,
        safety_checker=onnx_safety_checker,
        feature_extractor=feature_extractor,
    )

    onnx_pipeline.save_pretrained(output_path)
    print("ONNX pipeline saved to", output_path)

    del pipeline
    del onnx_pipeline
    # smoke test: the saved pipeline must load back on the CPU execution provider
    _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
    print("ONNX pipeline is loadable")


if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the conversion.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
    )

    parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--opset",
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")

    args = parser.parse_args()

    convert_models(args.model_path, args.output_path, args.opset, args.fp16)