Unverified Commit 141cd52d authored by Kristian Mischke's avatar Kristian Mischke Committed by GitHub
Browse files

Fix LLMGroundedDiffusionPipeline super class arguments (#5993)

* make `requires_safety_checker` a kwarg instead of a positional argument as it's more future-proof

* apply `make style` formatting edits

* add image_encoder to arguments and pass to super constructor
parent f72b28c7
...@@ -23,7 +23,7 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -23,7 +23,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense from diffusers.models.attention import Attention, GatedSelfAttentionDense
...@@ -272,10 +272,19 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline): ...@@ -272,10 +272,19 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
scheduler: KarrasDiffusionSchedulers, scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker, safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor, feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True, requires_safety_checker: bool = True,
): ):
super().__init__( super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
) )
self.register_attn_hooks(unet) self.register_attn_hooks(unet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment