"vscode:/vscode.git/clone" did not exist on "a3faf3f260101c2bc09112ebb514d1e0d0220e1a"
Unverified commit b345c74d, authored by Patrick von Platen, committed by GitHub

Make sure all pipelines can run with batched input (#1669)



* [SD] Make sure batched input works correctly

* uP

* uP

* up

* up

* uP

* up

* fix mask stuff

* up

* uP

* more up

* up

* uP

* up

* finish

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent b4170422
@@ -218,6 +218,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         else:
             timestep_embed = timestep_embed[..., None]
             timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
 
         # 2. down
         down_block_res_samples = ()
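The added `broadcast_to` expands a timestep embedding computed for a single scalar timestep to the batch size of `sample`. A minimal sketch with made-up shapes (not from the commit):

import torch

sample = torch.randn(4, 14, 8)          # (batch, channels, length)
timestep_embed = torch.randn(1, 32, 8)  # embedding computed for one scalar timestep

# torch.Size objects concatenate with "+", so this builds (4, 32, 8) and
# expands the singleton batch dimension without copying data.
timestep_embed = timestep_embed.broadcast_to(sample.shape[:1] + timestep_embed.shape[1:])
print(timestep_embed.shape)  # torch.Size([4, 32, 8])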
......
@@ -249,9 +249,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -44,13 +44,24 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
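Illustrative use of the batched preprocess above (a sketch; sizes are made up): single PIL images, lists of PIL images, and tensors are all accepted now.

import PIL.Image
import torch

imgs = [PIL.Image.new("RGB", (512, 512)) for _ in range(3)]

batch = preprocess(imgs)       # list of PIL images -> (3, 3, 512, 512) tensor in [-1, 1]
single = preprocess(imgs[0])   # a single PIL image still works (batch of 1)
passthrough = preprocess(torch.randn(3, 3, 512, 512))  # tensors are returned unchanged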
@@ -81,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -246,9 +257,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -510,8 +521,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -46,7 +46,6 @@ class DDIMPipeline(DiffusionPipeline):
         use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Args:
......
@@ -109,16 +109,18 @@ def prepare_mask_and_masked_image(image, mask):
         raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
+            image = [image]
 
-        image = image[None].transpose(0, 3, 1, 2)
+        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
 
         # preprocess mask
         if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
-            mask = mask.astype(np.float32) / 255.0
+            mask = [mask]
 
-        mask = mask[None, None]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
 
         # paint-by-example inverses the mask
         mask = 1 - mask
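A quick sketch of what the batched path yields, assuming prepare_mask_and_masked_image as defined above returns the (inverted) mask and the masked image; sizes here are illustrative:

import PIL.Image

images = [PIL.Image.new("RGB", (64, 64), "white") for _ in range(2)]
masks = [PIL.Image.new("L", (64, 64), 0) for _ in range(2)]

mask, masked_image = prepare_mask_and_masked_image(images, masks)
print(mask.shape)          # torch.Size([2, 1, 64, 64]); inverted, as the comment above notes
print(masked_image.shape)  # torch.Size([2, 3, 64, 64]), scaled to [-1, 1]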
@@ -159,7 +161,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -323,8 +325,22 @@ class PaintByExamplePipeline(DiffusionPipeline):
         masked_image_latents = 0.18215 * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the total"
+                    " requested batch size is divisible by the number of masks that you pass."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the total requested batch size is divisible by the number of images that you pass."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
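The new guard replaces a blind repeat(batch_size, ...) that over-duplicated masks which were already batched. A toy sketch of the duplication rule:

import torch

batch_size = 4                    # e.g. 2 examples with num_images_per_prompt=2
mask = torch.zeros(2, 1, 64, 64)  # two masks passed by the user

if mask.shape[0] < batch_size:
    # the pipeline raises ValueError when batch_size does not divide evenly
    mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
print(mask.shape)  # torch.Size([4, 1, 64, 64])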
@@ -351,7 +367,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
         if do_classifier_free_guidance:
             uncond_embeddings = self.image_encoder.uncond_vector
-            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
 
             # For classifier free guidance, we need to do two forward passes.
......
@@ -35,14 +35,26 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
@@ -279,9 +291,9 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -551,8 +563,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -32,13 +32,26 @@ from . import StableDiffusionPipelineOutput
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -77,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -325,8 +338,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......
@@ -248,9 +248,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -41,14 +41,26 @@ from ...utils import PIL_INTERPOLATION, deprecate, logging
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
@@ -189,9 +201,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -366,12 +378,13 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
     def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
         if isinstance(image, PIL.Image.Image):
-            width, height = image.size
-            width, height = map(lambda dim: dim - dim % 32, (width, height))  # resize to integer multiple of 32
-            image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-            width, height = image.size
+            image = [image]
         else:
-            width, height = image.shape[-2:]
+            image = [img for img in image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            width, height = image[0].size
+        else:
+            width, height = image[0].shape[-2:]
 
         if depth_map is None:
@@ -493,7 +506,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
         )
 
-        # 4. Prepare depth mask
+        # 4. Preprocess image
         depth_mask = self.prepare_depth_map(
             image,
             depth_map,
@@ -503,11 +516,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             device,
         )
 
-        # 5. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
-        else:
-            image = 2.0 * (image / 255.0) - 1.0
+        # 5. Prepare depth mask
+        image = preprocess(image)
 
         # 6. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -65,7 +65,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
......
@@ -43,13 +43,24 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
-    w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.0 * image - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
+    return image
 
 
 class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
@@ -79,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
@@ -248,9 +259,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -515,8 +526,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        if isinstance(image, PIL.Image.Image):
-            image = preprocess(image)
+        image = preprocess(image)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
......
@@ -107,14 +107,29 @@ def prepare_mask_and_masked_image(image, mask):
     elif isinstance(mask, torch.Tensor):
         raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
-        if isinstance(image, PIL.Image.Image):
-            image = np.array(image.convert("RGB"))
-        image = image[None].transpose(0, 3, 1, 2)
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
         image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
 
-        if isinstance(mask, PIL.Image.Image):
-            mask = np.array(mask.convert("L"))
-            mask = mask.astype(np.float32) / 255.0
-        mask = mask[None, None]
+        # preprocess mask
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+            mask = [mask]
+
+        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+            mask = mask.astype(np.float32) / 255.0
+        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
 
         mask[mask < 0.5] = 0
         mask[mask >= 0.5] = 1
         mask = torch.from_numpy(mask)
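The inpaint variant additionally accepts numpy arrays. A hedged usage sketch (shapes illustrative; note numpy masks are expected to already be in [0, 1], since the / 255.0 scaling only runs on the PIL branch):

import numpy as np
import PIL.Image

images = [PIL.Image.new("RGB", (64, 64)) for _ in range(2)]
masks = [np.zeros((64, 64), dtype=np.float32) for _ in range(2)]

mask, masked_image = prepare_mask_and_masked_image(images, masks)
print(mask.shape)          # torch.Size([2, 1, 64, 64]), binarized at 0.5
print(masked_image.shape)  # torch.Size([2, 3, 64, 64]) in [-1, 1]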
@@ -151,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker"]
 
     def __init__(
         self,
@@ -313,9 +328,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -481,8 +496,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         masked_image_latents = 0.18215 * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(batch_size, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the total"
+                    " requested batch size is divisible by the number of masks that you pass."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the total requested batch size is divisible by the number of images that you pass."
+                )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
......
@@ -92,7 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["feature_extractor"]
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
     def __init__(
@@ -261,9 +261,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -192,9 +192,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
......
@@ -32,15 +32,23 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def preprocess(image):
-    # resize to multiple of 64
-    width, height = image.size
-    width = width - width % 64
-    height = height - height % 64
-    image = image.resize((width, height))
-
-    image = np.array(image.convert("RGB"))
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+    if isinstance(image, torch.Tensor):
+        return image
+    elif isinstance(image, PIL.Image.Image):
+        image = [image]
+
+    if isinstance(image[0], PIL.Image.Image):
+        w, h = image[0].size
+        w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+
+        image = [np.array(i.resize((w, h)))[None, :] for i in image]
+        image = np.concatenate(image, axis=0)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image.transpose(0, 3, 1, 2)
+        image = 2.0 * image - 1.0
+        image = torch.from_numpy(image)
+    elif isinstance(image[0], torch.Tensor):
+        image = torch.cat(image, dim=0)
     return image
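Note the upscaler snaps sizes down to a multiple of 64 rather than the 32 used by the other pipelines, e.g.:

w, h = 500, 333
w, h = map(lambda x: x - x % 64, (w, h))
print(w, h)  # 448 320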
@@ -156,9 +164,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
-        if not torch.equal(text_input_ids, untruncated_ids):
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
@@ -407,10 +415,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
         )
 
         # 4. Preprocess image
-        image = [image] if isinstance(image, PIL.Image.Image) else image
-        if isinstance(image, list):
-            image = [preprocess(img) for img in image]
-            image = torch.cat(image, dim=0)
+        image = preprocess(image)
         image = image.to(dtype=text_embeddings.dtype, device=device)
 
         # 5. set timesteps
......
@@ -64,6 +64,7 @@ class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
         inputs = {
+            "batch_size": 1,
             "generator": generator,
             "num_inference_steps": 4,
         }
......
@@ -52,6 +52,7 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
         inputs = {
+            "batch_size": 1,
             "generator": generator,
             "num_inference_steps": 2,
             "output_type": "numpy",
......
@@ -25,7 +25,7 @@ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.utils import floats_tensor, load_image, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from PIL import Image
-from transformers import CLIPVisionConfig
+from transformers import CLIPImageProcessor, CLIPVisionConfig
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -76,6 +76,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             patch_size=4,
         )
         image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
 
         components = {
             "unet": unet,
@@ -83,7 +84,7 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "vae": vae,
             "image_encoder": image_encoder,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components
@@ -100,7 +101,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
         mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
         example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
-        example_image = self.convert_to_pt(example_image)
 
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
......
@@ -29,7 +29,8 @@ from diffusers import (
 )
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
-from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -74,19 +75,22 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
             patch_size=4,
         )
         image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
 
         components = {
             "unet": unet,
             "scheduler": scheduler,
             "vae": vae,
             "image_encoder": image_encoder,
+            "feature_extractor": feature_extractor,
             "safety_checker": None,
-            "feature_extractor": None,
         }
         return components
 
     def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed))
+        image = image.cpu().permute(0, 2, 3, 1)[0]
+        image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
@@ -112,7 +116,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904])
+        expected_slice = np.array([0.5167, 0.5746, 0.4835, 0.4914, 0.5605, 0.4691, 0.5201, 0.4898, 0.4958])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     def test_stable_diffusion_img_variation_multiple_images(self):
@@ -123,7 +127,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = 2 * [inputs["image"]]
         output = sd_pipe(**inputs)
         image = output.images
@@ -131,7 +135,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         image_slice = image[-1, -3:, -3:, -1]
 
         assert image.shape == (2, 64, 64, 3)
-        expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281])
+        expected_slice = np.array([0.6568, 0.5470, 0.5684, 0.5444, 0.5945, 0.6221, 0.5508, 0.5531, 0.5263])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     def test_stable_diffusion_img_variation_num_images_per_prompt(self):
@@ -150,7 +154,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         # test num_images_per_prompt=1 (default) for batch of images
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
 
         images = sd_pipe(**inputs).images
         assert images.shape == (batch_size, 64, 64, 3)
@@ -165,7 +169,7 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
         # test num_images_per_prompt for batch of prompts
         batch_size = 2
         inputs = self.get_dummy_inputs(device)
-        inputs["image"] = inputs["image"].repeat(batch_size, 1, 1, 1)
+        inputs["image"] = batch_size * [inputs["image"]]
 
         images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
 
         assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
......
@@ -30,7 +30,7 @@ from diffusers import (
 )
 from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -77,6 +77,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
 
         components = {
             "unet": unet,
@@ -85,7 +86,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
             "safety_checker": None,
-            "feature_extractor": None,
+            "feature_extractor": feature_extractor,
         }
         return components
......