Unverified Commit b345c74d authored by Patrick von Platen, committed by GitHub

Make sure all pipelines can run with batched input (#1669)



* [SD] Make sure batched input works correctly

* uP

* uP

* up

* up

* uP

* up

* fix mask stuff

* up

* uP

* more up

* up

* uP

* up

* finish

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent b4170422
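
Note: for context, a rough sketch of the kind of batched call this commit is meant to support across the pipelines changed below. The checkpoint id, image URL, and parameter values are illustrative placeholders, not part of this commit:

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

# Any Stable Diffusion img2img checkpoint works; this id is only an example.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/init.png")  # placeholder URL

# After this change, prompts and init images can be passed as equal-length lists
# and are processed as one batch, instead of requiring a manual tensor repeat.
prompts = ["a fantasy landscape", "a watercolor harbor at dusk"]
images = [init_image, init_image]

out = pipe(prompt=prompts, image=images, strength=0.75, num_inference_steps=25)
print(len(out.images))  # 2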
@@ -218,6 +218,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         else:
             timestep_embed = timestep_embed[..., None]
             timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

         # 2. down
         down_block_res_samples = ()
......
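The added broadcast_to line expands a timestep embedding that carries a leading batch dimension of 1 to the sample's batch size. A minimal standalone sketch of the shape logic (the shapes here are made up):

import torch

sample = torch.randn(4, 2, 64)          # (batch, channels, length)
timestep_embed = torch.randn(1, 8, 64)  # embedding computed for a single scalar timestep

# Expand along the batch dimension only, so one timestep can drive a batched sample.
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
print(timestep_embed.shape)  # torch.Size([4, 8, 64])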
@@ -249,9 +249,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -44,13 +44,24 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
@@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
@@ -246,9 +257,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -510,8 +521,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         )

         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
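The same batched preprocess helper is copied into several pipelines below. A rough usage sketch, assuming it is imported from the img2img module it is copied from (the image size is an arbitrary multiple of 32):

import numpy as np
import PIL.Image
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess

# Two equally sized RGB images; a single PIL image or a torch.Tensor is also accepted.
images = [PIL.Image.fromarray(np.zeros((96, 96, 3), dtype=np.uint8)) for _ in range(2)]

batch = preprocess(images)
print(batch.shape)  # torch.Size([2, 3, 96, 96]), values scaled to [-1, 1]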
@@ -46,7 +46,6 @@ class DDIMPipeline(DiffusionPipeline):
         use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
......
@@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask):
             raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
-            image = image[None].transpose(0, 3, 1, 2)
+            image = [image]
+
+        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

+        # preprocess mask
         if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
-            mask = mask.astype(np.float32) / 255.0
-            mask = mask[None, None]
+            mask = [mask]
+
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0

         # paint-by-example inverses the mask
         mask = 1 - mask
@@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
@@ -323,8 +325,22 @@ class PaintByExamplePipeline(DiffusionPipeline):
         masked_image_latents = 0.18215 * masked_image_latents

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
@@ -351,7 +367,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_embeddings = self.image_encoder.uncond_vector
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)

             # For classifier free guidance, we need to do two forward passes.
......
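The new duplication checks above only repeat masks and latents when the requested batch size is an exact multiple of what was passed. A minimal standalone sketch of that rule (tensor sizes are illustrative):

import torch

batch_size = 4
mask = torch.ones(2, 1, 8, 8)  # two masks supplied by the caller

if mask.shape[0] < batch_size:
    # Refuse ambiguous duplication: 2 masks can be tiled to a batch of 4, but not to 5.
    if batch_size % mask.shape[0] != 0:
        raise ValueError("number of masks must evenly divide the requested batch size")
    mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

print(mask.shape)  # torch.Size([4, 1, 8, 8])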
@@ -35,14 +35,26 @@ from .safety_checker import StableDiffusionSafetyChecker

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image

 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
@@ -279,9 +291,9 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -551,8 +563,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         )

         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -32,13 +32,26 @@ from . import StableDiffusionPipelineOutput

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image

 class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
@@ -325,8 +338,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......
@@ -248,9 +248,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
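Several pipelines in this commit switch the truncation check from padding="max_length" to padding="longest". A hedged sketch of the effect, using the public CLIP tokenizer purely for illustration: the untruncated ids are now only as long as the longest prompt, so the warning fires only when a prompt really exceeds model_max_length.

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
prompt = ["a photo of an astronaut riding a horse"]

text_input_ids = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# For a short prompt, untruncated_ids is shorter than the padded ids, so no warning is emitted.
truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
    text_input_ids, untruncated_ids
)
print(truncated)  # False for a prompt shorter than model_max_length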
@@ -41,14 +41,26 @@ from ...utils import PIL_INTERPOLATION, deprecate, logging

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image

 class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@@ -189,9 +201,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -366,12 +378,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
     def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
         if isinstance(image, PIL.Image.Image):
-            width, height = image.size
-            width, height = map(lambda dim: dim - dim % 32, (width, height))  # resize to integer multiple of 32
-            image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-            width, height = image.size
+            image = [image]
         else:
             image = [img for img in image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            width, height = image[0].size
+        else:
             width, height = image[0].shape[-2:]

         if depth_map is None:
@@ -493,7 +506,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )

-        # 4. Prepare depth mask
+        # 4. Preprocess image
         depth_mask = self.prepare_depth_map(
             image,
             depth_map,
@@ -503,11 +516,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             device,
         )

-        # 5. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
-        else:
-            image = 2.0 * (image / 255.0) - 1.0
+        # 5. Prepare depth mask
+        image = preprocess(image)

         # 6. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
......
@@ -43,13 +43,24 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image

 class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -79,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
@@ -248,9 +259,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -515,8 +526,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         )

         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -107,14 +107,29 @@ def prepare_mask_and_masked_image(image, mask):
         elif isinstance(mask, torch.Tensor):
             raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
-        if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
-            image = image[None].transpose(0, 3, 1, 2)
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-        if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
+
+        # preprocess mask
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+            mask = [mask]
+
+        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
             mask = mask.astype(np.float32) / 255.0
-            mask = mask[None, None]
+        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
         mask[mask < 0.5] = 0
         mask[mask >= 0.5] = 1
         mask = torch.from_numpy(mask)
@@ -151,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]

     def __init__(
         self,
@@ -313,9 +328,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -481,8 +496,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         masked_image_latents = 0.18215 * masked_image_latents

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
......
@@ -92,7 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["feature_extractor"]

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
@@ -261,9 +261,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -192,9 +192,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -32,15 +32,23 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 def preprocess(image):
-    # resize to multiple of 64
-    width, height = image.size
-    width = width - width % 64
-    height = height - height % 64
-    image = image.resize((width, height))
-
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h)))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
     return image
@@ -156,9 +164,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             return_tensors="pt",
        )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
-        if not torch.equal(text_input_ids, untruncated_ids):
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -407,10 +415,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
         )

         # 4. Preprocess image
-        image = [image] if isinstance(image, PIL.Image.Image) else image
-        if isinstance(image, list):
-            image = [preprocess(img) for img in image]
-            image = torch.cat(image, dim=0)
+        image = preprocess(image)
         image = image.to(dtype=text_embeddings.dtype, device=device)

         # 5. set timesteps
......
@@ -64,6 +64,7 @@ class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)

         inputs = {
+            "batch_size": 1,
             "generator": generator,
             "num_inference_steps": 4,
         }
......
@@ -52,6 +52,7 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)

         inputs = {
+            "batch_size": 1,
             "generator": generator,
             "num_inference_steps": 2,
             "output_type": "numpy",
......
@@ -25,7 +25,7 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.utils import floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from PIL import Image
-from transformers import CLIPVisionConfig
+from transformers import CLIPImageProcessor, CLIPVisionConfig

 from ...test_pipelines_common import PipelineTesterMixin
@@ -76,6 +76,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             patch_size=4,
         )
         image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
@@ -83,7 +84,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "vae": vae,
             "image_encoder": image_encoder,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components
@@ -100,7 +101,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
         example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
-        example_image = self.convert_to_pt(example_image)

         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
......
@@ -29,7 +29,8 @@ from diffusers import (
 )
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
-from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection

 from ...test_pipelines_common import PipelineTesterMixin
@@ -74,19 +75,22 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             patch_size=4,
         )
         image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
             "scheduler": scheduler,
             "vae": vae,
             "image_encoder": image_encoder,
+            "feature_extractor": feature_extractor,
             "safety_checker": None,
-            "feature_extractor": None,
         }
         return components

     def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
@@ -112,7 +116,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904])
+        expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_img_variation_multiple_images(self):
@@ -123,7 +127,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         sd_pipe.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = 2 * [inputs["image"]]
         output = sd_pipe(**inputs)
         image = output.images
@@ -131,7 +135,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         image_slice = image[-1, -3:, -3:, -1]

         assert image.shape == (2, 64, 64, 3)
-        expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281])
+        expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

     def test_stable_diffusion_img_variation_num_images_per_prompt(self):
@@ -150,7 +154,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # test num_images_per_prompt=1 (default) for batch of images
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
         images = sd_pipe(**inputs).images

         assert images.shape == (batch_size, 64, 64, 3)
@@ -165,7 +169,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # test num_images_per_prompt for batch of prompts
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
         images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images

         assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
......
@@ -30,7 +30,7 @@ from diffusers import (
 )
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin
@@ -77,6 +77,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)

         components = {
             "unet": unet,
@@ -85,7 +86,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components
......