Commit f38e3626 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 5f826a35
...@@ -13,16 +13,13 @@ ...@@ -13,16 +13,13 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import os
import shutil
from pathlib import Path from pathlib import Path
import torch import torch
from packaging import version
from torch.onnx import export from torch.onnx import export
import onnx from diffusers import AutoencoderKL
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline, AutoencoderKL
from packaging import version
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
...@@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F ...@@ -79,9 +76,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
output_path = Path(output_path) output_path = Path(output_path)
# VAE DECODER # VAE DECODER
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae") vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
vae_latent_channels = vae_decoder.config.latent_channels vae_latent_channels = vae_decoder.config.latent_channels
vae_out_channels = vae_decoder.config.out_channels
# forward only through the decoder part # forward only through the decoder part
vae_decoder.forward = vae_decoder.decode vae_decoder.forward = vae_decoder.decode
onnx_export( onnx_export(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment