Unverified Commit 6bd005eb authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

[ONNX] Collate the external weights, speed up loading from the hub (#610)

parent a9fdb3de
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import shutil
from pathlib import Path from pathlib import Path
import torch import torch
from torch.onnx import export from torch.onnx import export
import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version from packaging import version
...@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int): ...@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
) )
# UNET # UNET
unet_path = output_path / "unet" / "model.onnx"
onnx_export( onnx_export(
pipeline.unet, pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx", output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={ dynamic_axes={
...@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int): ...@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
opset=opset, opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
) )
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)
# VAE ENCODER # VAE ENCODER
vae_encoder = pipeline.vae vae_encoder = pipeline.vae
......
...@@ -90,8 +90,10 @@ _deps = [ ...@@ -90,8 +90,10 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6", "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6", "jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4", "modelcards>=0.1.4",
"numpy", "numpy",
"onnxruntime",
"onnxruntime-gpu",
"pytest", "pytest",
"pytest-timeout", "pytest-timeout",
"pytest-xdist", "pytest-xdist",
...@@ -100,6 +102,7 @@ _deps = [ ...@@ -100,6 +102,7 @@ _deps = [
"requests", "requests",
"tensorboard", "tensorboard",
"torch>=1.4", "torch>=1.4",
"torchvision",
"transformers>=4.21.0", "transformers>=4.21.0",
] ]
...@@ -171,10 +174,20 @@ extras = {} ...@@ -171,10 +174,20 @@ extras = {}
extras = {} extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = ["hf-doc-builder"] extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"] extras["test"] = deps_list(
"datasets",
"onnxruntime",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch") extras["torch"] = deps_list("torch")
if os.name == "nt": # windows if os.name == "nt": # windows
......
...@@ -15,8 +15,10 @@ deps = { ...@@ -15,8 +15,10 @@ deps = {
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4", "modelcards": "modelcards>=0.1.4",
"numpy": "numpy", "numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest", "pytest": "pytest",
"pytest-timeout": "pytest-timeout", "pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist", "pytest-xdist": "pytest-xdist",
...@@ -25,5 +27,6 @@ deps = { ...@@ -25,5 +27,6 @@ deps = {
"requests": "requests", "requests": "requests",
"tensorboard": "tensorboard", "tensorboard": "tensorboard",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.21.0", "transformers": "transformers>=4.21.0",
} }
...@@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1373,12 +1373,9 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_stable_diffusion_onnx(self): def test_stable_diffusion_onnx(self):
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
with tempfile.TemporaryDirectory() as tmpdirname: )
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
np.random.seed(0) np.random.seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment