"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "2d39723610357e653c0e0427fb7877dc3e274661"
Unverified commit bf92e746, authored by gujing and committed by GitHub

fix StableDiffusionTensorRT super args error (#6009)

parent b785a155
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -724,7 +725,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )
         self.vae.forward = self.vae.decode
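The change is the same in all three TensorRT community pipelines: newer diffusers releases insert an image_encoder parameter between feature_extractor and requires_safety_checker in the parent pipelines' __init__, so the old all-positional super().__init__(...) call bound the trailing boolean to the wrong parameter. The following is a minimal, self-contained sketch with hypothetical ParentPipeline / BrokenChild / FixedChild classes (not diffusers code) showing how a positional call misbinds after such a signature change, and why passing the trailing arguments by keyword, as this commit does, avoids it.

class ParentPipeline:
    # Simplified stand-in for the parent Stable Diffusion pipeline: a new
    # image_encoder parameter was added ahead of requires_safety_checker.
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler,
                 safety_checker, feature_extractor,
                 image_encoder=None, requires_safety_checker=True):
        self.image_encoder = image_encoder
        self.requires_safety_checker = requires_safety_checker


class BrokenChild(ParentPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler,
                 safety_checker, feature_extractor, requires_safety_checker=True):
        # All-positional call: requires_safety_checker lands in image_encoder,
        # and the parent's requires_safety_checker silently keeps its default.
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)


class FixedChild(ParentPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler,
                 safety_checker, feature_extractor,
                 image_encoder=None, requires_safety_checker=True):
        # Keyword arguments bind by name, so a new parameter in the parent
        # signature cannot shift the meaning of the trailing arguments.
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker=safety_checker,
                         feature_extractor=feature_extractor,
                         image_encoder=image_encoder,
                         requires_safety_checker=requires_safety_checker)


broken = BrokenChild(*(None,) * 7, requires_safety_checker=False)
assert broken.image_encoder is False            # the flag ended up in the wrong slot
assert broken.requires_safety_checker is True   # the intended value was dropped

fixed = FixedChild(*(None,) * 7, requires_safety_checker=False)
assert fixed.image_encoder is None
assert fixed.requires_safety_checker is False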
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae", "vae_encoder"],
         image_height: int = 512,
@@ -725,7 +726,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
         )
         self.vae.forward = self.vae.decode
@@ -40,7 +40,7 @@ from polygraphy.backend.trt import (
     save_engine,
 )
 from polygraphy.backend.trt import util as trt_util
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
         stages=["clip", "unet", "vae"],
         image_height: int = 768,
@@ -639,7 +640,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
         timing_cache: str = "timing_cache",
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+            requires_safety_checker=requires_safety_checker,
        )
         self.vae.forward = self.vae.decode
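For reference, these community pipelines are normally loaded through DiffusionPipeline.from_pretrained with the custom_pipeline argument. A rough usage sketch for the txt2img variant follows; the model ID, scheduler choice, and output filename are illustrative and not taken from this commit.

import torch
from diffusers import DDIMScheduler, DiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion model compatible with the
# TensorRT community pipeline should work.
model_id = "stabilityai/stable-diffusion-2-1"

# These pipelines expect a DDIM scheduler.
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

pipe = DiffusionPipeline.from_pretrained(
    model_id,
    custom_pipeline="stable_diffusion_tensorrt_txt2img",  # community TensorRT txt2img pipeline
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

# First call builds the ONNX models and TensorRT engines, so it is slow;
# subsequent calls reuse the cached engines.
image = pipe("a photograph of Mt. Fuji during cherry blossom").images[0]
image.save("fuji.png")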