Unverified commit b562b661 authored by Patrick von Platen, committed by GitHub

Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071)

* add text embeds to sd

* add text embeds to sd

* finish tests

* finish

* finish

* make style

* fix tests

* make style

* make style

* up

* better docs

* fix

* fix

* new try

* up

* up

* finish
parent c1184918
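The change adds `prompt_embeds` and `negative_prompt_embeds` arguments so callers can skip the pipelines' built-in text encoding and pass pre-computed embeddings instead. Below is a minimal sketch of how this enables prompt weighting, assuming the post-merge `StableDiffusionPipeline` API; the checkpoint name and the token indices being re-weighted are illustrative only:

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"

# Encode the prompt manually instead of letting the pipeline do it.
text_inputs = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
uncond_inputs = pipe.tokenizer(
    "", padding="max_length", max_length=pipe.tokenizer.model_max_length, return_tensors="pt"
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to("cuda"))[0]
    negative_prompt_embeds = pipe.text_encoder(uncond_inputs.input_ids.to("cuda"))[0]

# Crude prompt weighting: scale the embeddings of a few tokens (indices chosen for illustration).
prompt_embeds[:, 2:6] = prompt_embeds[:, 2:6] * 1.1

# Pass the embeddings directly; `prompt` must then be left unset (see `check_inputs` below).
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]
image.save("astronaut.png")
```

Because the new `check_inputs` logic rejects passing both `prompt` and `prompt_embeds` (or both `negative_prompt` and `negative_prompt_embeds`), only one of each pair may be supplied per call.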
@@ -136,12 +136,21 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
@@ -149,47 +158,67 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
        """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-            attention_mask = text_inputs.attention_mask.to(device)
-        else:
-            attention_mask = None
-
-        text_embeddings = self.text_encoder(
-            text_input_ids.to(device),
-            attention_mask=attention_mask,
-        )
-        text_embeddings = text_embeddings[0]
+        if prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=self.tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = text_inputs.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            prompt_embeds = self.text_encoder(
+                text_input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            prompt_embeds = prompt_embeds[0]
+
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

+        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
@@ -209,7 +238,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

-            max_length = text_input_ids.shape[-1]
+            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
@@ -223,23 +252,27 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
            else:
                attention_mask = None

-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]

+        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
@@ -325,8 +358,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
    @torch.no_grad()
    def __call__(
        self,
-        prompt: Union[str, List[str]],
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
+        prompt: Union[str, List[str]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        noise_level: int = 20,
@@ -335,6 +368,8 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -344,8 +379,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        Function invoked when calling the pipeline for generation.

        Args:
-            prompt (`str` or `List[str]`):
-                The prompt or prompts to guide the image generation.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
            image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
                `Image`, or tensor representing an image batch which will be upscaled.
            num_inference_steps (`int`, *optional*, defaults to 50):
@@ -358,8 +394,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
@@ -372,6 +409,13 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -422,6 +466,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        # 1. Check inputs
        self.check_inputs(prompt, image, noise_level, callback_steps)

+        if image is None:
+            raise ValueError("`image` input cannot be undefined.")
+
        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
@@ -431,13 +478,19 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(
-            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Preprocess image
        image = preprocess(image)
-        image = image.to(dtype=text_embeddings.dtype, device=device)
+        image = image.to(dtype=prompt_embeds.dtype, device=device)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -445,7 +498,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        # 5. Add noise to image
        noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
-        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
+        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
        image = self.low_res_scheduler.add_noise(image, noise, noise_level)

        batch_multiplier = 2 if do_classifier_free_guidance else 1
@@ -460,7 +513,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
            num_channels_latents,
            height,
            width,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
            device,
            generator,
            latents,
@@ -493,7 +546,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
                # predict the noise residual
                noise_pred = self.unet(
-                    latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
+                    latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level
                ).sample

                # perform guidance
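Since `_encode_prompt` now derives the batch size from `prompt_embeds.shape[0]` when no `prompt` is given, pre-computed embeddings can also be manipulated more freely, for example by blending two prompts in embedding space. A hypothetical continuation of the sketch above, reusing `pipe`; the helper and the prompts are illustrative only:

```python
def encode(text: str) -> torch.Tensor:
    # Helper for this sketch only; mirrors the tokenizer/text_encoder calls shown in the diff above.
    inputs = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return pipe.text_encoder(inputs.input_ids.to("cuda"))[0]

emb_a = encode("an oil painting of a lighthouse")
emb_b = encode("a watercolor of a lighthouse")

# 70/30 interpolation between the two prompts; the result keeps the (1, 77, hidden_dim) shape.
mixed_embeds = 0.7 * emb_a + 0.3 * emb_b
image = pipe(prompt_embeds=mixed_embeds).images[0]
```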
@@ -217,7 +217,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        Encodes the prompt into text encoder hidden states.

        Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
@@ -253,16 +253,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        else:
            attention_mask = None

-        text_embeddings = self.text_encoder(
+        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
-        text_embeddings = text_embeddings[0]
+        prompt_embeds = prompt_embeds[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
@@ -299,16 +299,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
            else:
                attention_mask = None

-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
-            uncond_embeddings = uncond_embeddings[0]
+            negative_prompt_embeds = negative_prompt_embeds[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # Encode the safety concept text
            if enable_safety_guidance:
@@ -329,15 +329,15 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
                # For classifier free guidance + sld, we need to do three forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing three forward passes
-                text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings])
            else:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
-                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds

    def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
        if self.safety_checker is not None:
@@ -390,10 +390,16 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
-    def check_inputs(self, prompt, height, width, callback_steps):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -405,6 +411,32 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
                f" {type(callback_steps)}."
            )

+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -583,7 +615,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
            warnings.warn("Safety checker disabled!")

        # 3. Encode input prompt
-        text_embeddings = self._encode_prompt(
+        prompt_embeds = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
        )
@@ -598,7 +630,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
            num_channels_latents,
            height,
            width,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
            device,
            generator,
            latents,
@@ -621,7 +653,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

            # perform guidance
            if do_classifier_free_guidance:
@@ -680,7 +712,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        # 9. Run safety checker
        image, has_nsfw_concept, flagged_images = self.run_safety_checker(
-            image, device, text_embeddings.dtype, enable_safety_guidance
+            image, device, prompt_embeds.dtype, enable_safety_guidance
        )

        # 10. Convert to PIL
@@ -153,15 +153,15 @@ class UnCLIPPipeline(DiffusionPipeline):
            text_encoder_output = self.text_encoder(text_input_ids.to(device))

-            text_embeddings = text_encoder_output.text_embeds
+            prompt_embeds = text_encoder_output.text_embeds
            text_encoder_hidden_states = text_encoder_output.last_hidden_state
        else:
            batch_size = text_model_output[0].shape[0]
-            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -176,16 +176,16 @@ class UnCLIPPipeline(DiffusionPipeline):
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

-            uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
-            uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
@@ -199,12 +199,12 @@ class UnCLIPPipeline(DiffusionPipeline):
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
            text_mask = torch.cat([uncond_text_mask, text_mask])

-        return text_embeddings, text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_encoder_hidden_states, text_mask

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
@@ -336,7 +336,7 @@ class UnCLIPPipeline(DiffusionPipeline):
        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

-        text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
+        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
        )
@@ -349,7 +349,7 @@ class UnCLIPPipeline(DiffusionPipeline):
        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
            device,
            generator,
            prior_latents,
@@ -363,7 +363,7 @@ class UnCLIPPipeline(DiffusionPipeline):
            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
-                proj_embedding=text_embeddings,
+                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding
@@ -397,7 +397,7 @@ class UnCLIPPipeline(DiffusionPipeline):
        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
-            text_embeddings=text_embeddings,
+            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
@@ -136,10 +136,10 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_encoder_output = self.text_encoder(text_input_ids.to(device))

-        text_embeddings = text_encoder_output.text_embeds
+        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -155,16 +155,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

-            uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
-            uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
+            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
+            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
@@ -178,12 +178,12 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
            text_mask = torch.cat([uncond_text_mask, text_mask])

-        return text_embeddings, text_encoder_hidden_states, text_mask
+        return prompt_embeds, text_encoder_hidden_states, text_mask

    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
        dtype = next(self.image_encoder.parameters()).dtype
@@ -314,7 +314,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
        do_classifier_free_guidance = decoder_guidance_scale > 1.0

-        text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
+        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance
        )
@@ -323,7 +323,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
        # decoder
        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
-            text_embeddings=text_embeddings,
+            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
@@ -52,7 +52,7 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
        self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
        self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
+    def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
        if do_classifier_free_guidance:
            # Add the classifier free guidance embeddings to the image embeddings
            image_embeddings_batch_size = image_embeddings.shape[0]
@@ -63,15 +63,15 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
            image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)

        # The image embeddings batch size and the text embeddings batch size are equal
-        assert image_embeddings.shape[0] == text_embeddings.shape[0]
+        assert image_embeddings.shape[0] == prompt_embeds.shape[0]

-        batch_size = text_embeddings.shape[0]
+        batch_size = prompt_embeds.shape[0]

        # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
        # adding CLIP embeddings to the existing timestep embedding, ...
-        time_projected_text_embeddings = self.embedding_proj(text_embeddings)
+        time_projected_prompt_embeds = self.embedding_proj(prompt_embeds)
        time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
-        additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings
+        additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds

        # ... and by projecting CLIP embeddings into four
        # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
@@ -187,7 +187,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
        Encodes the prompt into text encoder hidden states.

        Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
@@ -227,16 +227,16 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
        else:
            attention_mask = None

-        text_embeddings = self.text_encoder(
+        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
-        text_embeddings = normalize_embeddings(text_embeddings)
+        prompt_embeds = normalize_embeddings(prompt_embeds)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        bs_embed, seq_len, _ = text_embeddings.shape
-        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
-        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
@@ -255,30 +255,30 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            else:
                attention_mask = None

-            uncond_embeddings = self.text_encoder(
+            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
-            uncond_embeddings = normalize_embeddings(uncond_embeddings)
+            negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        return text_embeddings
+        return prompt_embeds

    def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
@@ -313,18 +313,18 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
            uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
            pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
-            uncond_embeddings = self.image_encoder(pixel_values)
-            uncond_embeddings = normalize_embeddings(uncond_embeddings)
+            negative_prompt_embeds = self.image_encoder(pixel_values)
+            negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional embeddings into a single batch
            # to avoid doing two forward passes
-            image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings
@@ -524,9 +524,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompts
-        text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
+        prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
        image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
-        dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
+        dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1)
        prompt_types = ("text", "image")

        # 4. Prepare timesteps
@@ -114,7 +114,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
        Encodes the prompt into text encoder hidden states.

        Args:
-            prompt (`str` or `list(int)`):
+            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
@@ -173,18 +173,18 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
            uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
            pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
-            uncond_embeddings = self.image_encoder(pixel_values)
-            uncond_embeddings = normalize_embeddings(uncond_embeddings)
+            negative_prompt_embeds = self.image_encoder(pixel_values)
+            negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = uncond_embeddings.shape[1]
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
-            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and conditional embeddings into a single batch
            # to avoid doing two forward passes
-            image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
+            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings
...@@ -138,7 +138,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -138,7 +138,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
Args: Args:
prompt (`str` or `list(int)`): prompt (`str` or `List[str]`):
prompt to be encoded prompt to be encoded
device: (`torch.device`): device: (`torch.device`):
torch device torch device
...@@ -181,16 +181,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -181,16 +181,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
text_embeddings = self.text_encoder( prompt_embeds = self.text_encoder(
text_input_ids.to(device), text_input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
text_embeddings = normalize_embeddings(text_embeddings) prompt_embeds = normalize_embeddings(prompt_embeds)
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape bs_embed, seq_len, _ = prompt_embeds.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -227,23 +227,23 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -227,23 +227,23 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
else: else:
attention_mask = None attention_mask = None
uncond_embeddings = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device),
attention_mask=attention_mask, attention_mask=attention_mask,
) )
uncond_embeddings = normalize_embeddings(uncond_embeddings) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
...@@ -273,10 +273,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -273,10 +273,16 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
return extra_step_kwargs return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps): def check_inputs(
if not isinstance(prompt, str) and not isinstance(prompt, list): self,
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -288,6 +294,32 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -288,6 +294,32 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
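# Hedged usage sketch (not part of this diff) of the new validation: a caller must
# provide exactly one of `prompt` or `prompt_embeds`, and matching shapes when both
# embedding tensors are given. `pipe` and `embeds` are illustrative assumptions.
#
#     pipe.check_inputs("a painting of a cat", 512, 512, 1)                        # ok
#     pipe.check_inputs(None, 512, 512, 1, prompt_embeds=embeds)                   # ok
#     pipe.check_inputs("a painting of a cat", 512, 512, 1, prompt_embeds=embeds)  # raises ValueError
#     pipe.check_inputs(None, 512, 512, 1)                                         # raises ValueError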
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
...@@ -412,7 +444,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -412,7 +444,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
) )
...@@ -427,7 +459,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -427,7 +459,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
num_channels_latents, num_channels_latents,
height, height,
width, width,
text_embeddings.dtype, prompt_embeds.dtype,
device, device,
generator, generator,
latents, latents,
...@@ -443,7 +475,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): ...@@ -443,7 +475,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
......
...@@ -120,7 +120,7 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -120,7 +120,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0]
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining # While CLIP does normalize the pooled output of the text transformer when combining
...@@ -128,15 +128,15 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -128,15 +128,15 @@ class VQDiffusionPipeline(DiffusionPipeline):
# #
# CLIP normalizing the pooled output. # CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate text embeddings for each generation per prompt # duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance: if do_classifier_free_guidance:
if self.learned_classifier_free_sampling_embeddings.learnable: if self.learned_classifier_free_sampling_embeddings.learnable:
uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings
uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1) negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1)
else: else:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -148,21 +148,21 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -148,21 +148,21 @@ class VQDiffusionPipeline(DiffusionPipeline):
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# See comment for normalizing text embeddings # See comment for normalizing text embeddings
uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True) negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = negative_prompt_embeds.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return text_embeddings return prompt_embeds
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -234,7 +234,7 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -234,7 +234,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)
if (callback_steps is None) or ( if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
...@@ -273,9 +273,7 @@ class VQDiffusionPipeline(DiffusionPipeline): ...@@ -273,9 +273,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
# predict the un-noised image # predict the un-noised image
# model_output == `log_p_x_0` # model_output == `log_p_x_0`
model_output = self.transformer( model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample
if do_classifier_free_guidance: if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2) model_output_uncond, model_output_text = model_output.chunk(2)
......
...@@ -120,7 +120,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -120,7 +120,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components) sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device) sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None) sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device) inputs = self.get_dummy_inputs(device)
...@@ -134,6 +134,82 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -134,6 +134,82 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
text_inputs = sd_pipe.tokenizer(
prompt,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
prompt_embeds = sd_pipe.text_encoder(text_inputs)[0]
inputs["prompt_embeds"] = prompt_embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
negative_prompt = 3 * ["this is a negative prompt"]
inputs["negative_prompt"] = negative_prompt
inputs["prompt"] = 3 * [inputs["prompt"]]
# forward
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
inputs = self.get_dummy_inputs(torch_device)
prompt = 3 * [inputs.pop("prompt")]
embeds = []
for p in [prompt, negative_prompt]:
text_inputs = sd_pipe.tokenizer(
p,
padding="max_length",
max_length=sd_pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_inputs = text_inputs["input_ids"].to(torch_device)
embeds.append(sd_pipe.text_encoder(text_inputs)[0])
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
# forward
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
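# Hedged end-to-end sketch (not part of this commit) of the workflow the two tests above
# exercise: embeddings are computed once with the pipeline's own tokenizer and text
# encoder, optionally rescaled per token for prompt weighting, and passed through
# `prompt_embeds` instead of `prompt`. The weighting line is an illustrative assumption,
# not an API added by this diff.
#
#     text_inputs = sd_pipe.tokenizer(
#         "a photo of an astronaut",
#         padding="max_length",
#         max_length=sd_pipe.tokenizer.model_max_length,
#         truncation=True,
#         return_tensors="pt",
#     )
#     prompt_embeds = sd_pipe.text_encoder(text_inputs.input_ids.to(torch_device))[0]
#     prompt_embeds[:, 4:6] *= 1.1  # emphasize selected token positions (illustrative)
#     image = sd_pipe(prompt_embeds=prompt_embeds, num_inference_steps=2).images[0]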
def test_stable_diffusion_ddim_factor_8(self): def test_stable_diffusion_ddim_factor_8(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
......