Unverified Commit 141cd52d authored by Kristian Mischke's avatar Kristian Mischke Committed by GitHub
Browse files

Fix LLMGroundedDiffusionPipeline super class arguments (#5993)

* make `requires_safety_checker` a kwarg instead of a positional argument as it's more future-proof

* apply `make style` formatting edits

* add image_encoder to arguments and pass to super constructor
parent f72b28c7
......@@ -23,7 +23,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense
......@@ -272,10 +272,19 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.register_attn_hooks(unet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment