Unverified commit a5720e9e, authored by Sayak Paul and committed by GitHub
Browse files

[PixArt-Alpha] Fix PixArt-Alpha pipeline when number of images to generate is more than 1 (#5752)



* does this fix things?

* attention mask use

* attention mask order

* better masking.

* add: test

* remove mask_feature

* test

* debug

* fix: tests

* deprecate mask_feature

* add deprecation test

* add slow test

* add print statements to retrieve the assertion values.

* fix for the 1024 fast test

* fix test

* fix the remaining

* Apply suggestions from code review

* more debug

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 16d50045
...@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, Transformer2DModel ...@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import DPMSolverMultistepScheduler from ...schedulers import DPMSolverMultistepScheduler
from ...utils import ( from ...utils import (
BACKENDS_MAPPING, BACKENDS_MAPPING,
deprecate,
is_bs4_available, is_bs4_available,
is_ftfy_available, is_ftfy_available,
logging, logging,
...@@ -162,8 +163,10 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -162,8 +163,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
clean_caption: bool = False, clean_caption: bool = False,
mask_feature: bool = True, **kwargs,
): ):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -189,10 +192,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -189,10 +192,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
string. string.
clean_caption (bool, defaults to `False`): clean_caption (bool, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding. If `True`, the function will preprocess and clean the provided caption before encoding.
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
""" """
embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None: if device is None:
device = self._execution_device device = self._execution_device
...@@ -229,13 +233,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -229,13 +233,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" {max_length} tokens: {removed_text}" f" {max_length} tokens: {removed_text}"
) )
attention_mask = text_inputs.attention_mask.to(device) prompt_attention_mask = text_inputs.attention_mask
prompt_embeds_attention_mask = attention_mask prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds[0]
else:
prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)
if self.text_encoder is not None: if self.text_encoder is not None:
dtype = self.text_encoder.dtype dtype = self.text_encoder.dtype
...@@ -250,8 +252,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -250,8 +252,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None: if do_classifier_free_guidance and negative_prompt_embeds is None:
...@@ -267,11 +269,11 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -267,11 +269,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
add_special_tokens=True, add_special_tokens=True,
return_tensors="pt", return_tensors="pt",
) )
attention_mask = uncond_input.attention_mask.to(device) negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder( negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
attention_mask=attention_mask,
) )
negative_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds[0]
...@@ -284,23 +286,13 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -284,23 +286,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes. negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
# Here we concatenate the unconditional and text embeddings into a single batch negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
# to avoid doing two forward passes
else: else:
negative_prompt_embeds = None negative_prompt_embeds = None
negative_prompt_attention_mask = None
# Perform additional masking. return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
masked_negative_prompt_embeds = (
negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
)
return masked_prompt_embeds, masked_negative_prompt_embeds
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta): def prepare_extra_step_kwargs(self, generator, eta):
...@@ -329,6 +321,8 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -329,6 +321,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
callback_steps, callback_steps,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -365,6 +359,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -365,6 +359,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
) )
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape: if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError( raise ValueError(
...@@ -372,6 +372,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -372,6 +372,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False): def _text_preprocessing(self, text, clean_caption=False):
...@@ -579,14 +585,16 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -579,14 +585,16 @@ class PixArtAlphaPipeline(DiffusionPipeline):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
clean_caption: bool = True, clean_caption: bool = True,
mask_feature: bool = True,
use_resolution_binning: bool = True, use_resolution_binning: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]: ) -> Union[ImagePipelineOutput, Tuple]:
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -630,9 +638,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -630,9 +638,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*): negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -648,11 +659,10 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -648,11 +659,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt. prompt.
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. use_resolution_binning (`bool` defaults to `True`):
use_resolution_binning: If set to `True`, the requested height and width are first mapped to the closest resolutions using
(`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, the requested resolution. Useful for generating non-square images.
they are resized back to the requested resolution. Useful for generating non-square images.
Examples: Examples:
...@@ -661,6 +671,9 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -661,6 +671,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images returned where the first element is a list with the generated images
""" """
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor
...@@ -669,7 +682,15 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -669,7 +682,15 @@ class PixArtAlphaPipeline(DiffusionPipeline):
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
self.check_inputs( self.check_inputs(
prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds prompt,
height,
width,
negative_prompt,
callback_steps,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) )
# 2. Default height and width to transformer # 2. Default height and width to transformer
...@@ -688,7 +709,12 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -688,7 +709,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt( (
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt, prompt,
do_classifier_free_guidance, do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
...@@ -696,11 +722,13 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -696,11 +722,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
device=device, device=device,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption, clean_caption=clean_caption,
mask_feature=mask_feature,
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
...@@ -758,6 +786,7 @@ class PixArtAlphaPipeline(DiffusionPipeline): ...@@ -758,6 +786,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
noise_pred = self.transformer( noise_pred = self.transformer(
latent_model_input, latent_model_input,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep, timestep=current_timestep,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
return_dict=False, return_dict=False,
......
...@@ -111,13 +111,20 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -111,13 +111,20 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_inference_steps = inputs["num_inference_steps"] num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"] output_type = inputs["output_type"]
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, mask_feature=False) (
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt)
# inputs with prompt converted to embeddings # inputs with prompt converted to embeddings
inputs = { inputs = {
"prompt_embeds": prompt_embeds, "prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None, "negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator, "generator": generator,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"output_type": output_type, "output_type": output_type,
...@@ -151,8 +158,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -151,8 +158,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# inputs with prompt converted to embeddings # inputs with prompt converted to embeddings
inputs = { inputs = {
"prompt_embeds": prompt_embeds, "prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None, "negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator, "generator": generator,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"output_type": output_type, "output_type": output_type,
...@@ -211,13 +220,15 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -211,13 +220,15 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_inference_steps = inputs["num_inference_steps"] num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"] output_type = inputs["output_type"]
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt) prompt_embeds, prompt_attn_mask, negative_prompt_embeds, neg_prompt_attn_mask = pipe.encode_prompt(prompt)
# inputs with prompt converted to embeddings # inputs with prompt converted to embeddings
inputs = { inputs = {
"prompt_embeds": prompt_embeds, "prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attn_mask,
"negative_prompt": None, "negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": neg_prompt_attn_mask,
"generator": generator, "generator": generator,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"output_type": output_type, "output_type": output_type,
...@@ -252,8 +263,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -252,8 +263,10 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# inputs with prompt converted to embeddings # inputs with prompt converted to embeddings
inputs = { inputs = {
"prompt_embeds": prompt_embeds, "prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attn_mask,
"negative_prompt": None, "negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds, "negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": neg_prompt_attn_mask,
"generator": generator, "generator": generator,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"output_type": output_type, "output_type": output_type,
...@@ -266,6 +279,40 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -266,6 +279,40 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4) self.assertLess(max_diff, 1e-4)
def test_inference_with_multiple_images_per_prompt(self):
    """Run the dummy pipeline with `num_images_per_prompt=2` and pin its output.

    The test asserts that the returned image batch has shape `(2, 8, 8, 3)` and that a
    3x3 slice taken from the first image (`[0, -3:, -3:, -1]`) stays within `1e-3` of
    the stored reference values.
    """
    cpu_device = "cpu"
    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.to(cpu_device)
    pipe.set_progress_bar_config(disable=None)

    # Start from the shared dummy inputs and only override the image count.
    call_kwargs = self.get_dummy_inputs(cpu_device)
    call_kwargs["num_images_per_prompt"] = 2

    images = pipe(**call_kwargs).images
    corner_patch = images[0, -3:, -3:, -1]

    self.assertEqual(images.shape, (2, 8, 8, 3))

    reference = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
    largest_deviation = np.abs(corner_patch.flatten() - reference).max()
    self.assertLessEqual(largest_deviation, 1e-3)
def test_raises_warning_for_mask_feature(self):
    """Check that passing `mask_feature` to the pipeline emits a `FutureWarning`.

    The dummy pipeline is called with `mask_feature=True` added to the dummy inputs;
    the call must raise a `FutureWarning` whose message mentions `mask_feature`.
    """
    device = "cpu"
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    inputs["mask_feature"] = True

    with self.assertWarns(FutureWarning) as warning_ctx:
        _ = pipe(**inputs).images

    # Inspect the recorded warning after the context manager has exited. Use a
    # unittest assertion rather than a bare `assert`, which is stripped under
    # `python -O` and gives no diagnostic message on failure.
    self.assertIn("mask_feature", str(warning_ctx.warning))
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3) self._test_inference_batch_single_identical(expected_max_diff=1e-3)
...@@ -290,7 +337,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -290,7 +337,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1323]) expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
...@@ -307,7 +354,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -307,7 +354,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0266]) expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
...@@ -323,7 +370,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -323,7 +370,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.1501, 0.1755, 0.1877, 0.1445, 0.1665, 0.1763, 0.1389, 0.176, 0.2031]) expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583])
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
...@@ -340,7 +387,26 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): ...@@ -340,7 +387,26 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.2515, 0.2593, 0.2593, 0.2544, 0.2759, 0.2788, 0.2812, 0.3169, 0.332]) expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817])
max_diff = np.abs(image_slice.flatten() - expected_slice).max() max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3) self.assertLessEqual(max_diff, 1e-3)
def test_pixart_1024_without_resolution_binning(self):
    """Compare pipeline output with and without `use_resolution_binning`.

    The same prompt is generated twice from seed 0, once with the pipeline defaults and
    once with `use_resolution_binning=False`; the two 3x3 image slices are expected to
    differ by more than the `1e-4` tolerances.
    """
    generator = torch.manual_seed(0)

    pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload()

    prompt = "A small cactus with a happy face in the Sahara desert."

    # NOTE(review): neither call passes `height`/`width`, so both runs use the
    # pipeline's default resolution -- confirm that this actually exercises binning.
    image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images
    image_slice = image[0, -3:, -3:, -1]

    # Re-seed so both runs start from the same random state.
    generator = torch.manual_seed(0)
    no_res_bin_image = pipe(
        prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False
    ).images
    no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]

    # unittest assertion instead of a bare `assert` (which is stripped under `python -O`),
    # matching the `self.assert*` style used by the neighbouring tests.
    self.assertFalse(np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment