"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "1a47fa08adb9d1aafd8454c43c0265f952ccbda0"
Unverified Commit bf92e746 authored by gujing's avatar gujing Committed by GitHub
Browse files

fix StableDiffusionTensorRT super args error (#6009)

parent b785a155
......@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
......@@ -709,6 +709,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
......@@ -724,7 +725,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
......
......@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
......@@ -710,6 +710,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
......@@ -725,7 +726,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
......
......@@ -40,7 +40,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
......@@ -624,6 +624,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
image_height: int = 768,
......@@ -639,7 +640,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment