Unverified commit 867e0c91 authored by Vinh H. Pham, committed by GitHub

StableDiffusionLatentUpscalePipeline - positive/negative prompt embeds support (#8947)



* make latent upscaler accept prompt embeds

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 16a3dad4
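In practice, the new arguments let callers pre-compute text embeddings once and reuse them across calls. A minimal usage sketch, not part of this commit — the model ID is the released checkpoint this pipeline targets, while the random low-res latents stand in for a base pipeline's output:

import torch
from diffusers import StableDiffusionLatentUpscalePipeline

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# encode_prompt() now returns a 4-tuple; the pooled outputs feed the
# upscaler's timestep conditioning further down in __call__.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = upscaler.encode_prompt(
    "an astronaut riding a horse, highly detailed",
    "cuda",
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)

low_res_latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")  # stand-in latents
image = upscaler(
    prompt=None,  # embeddings replace the string prompt
    image=low_res_latents,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_inference_steps=20,
    generator=torch.manual_seed(0),
).images[0]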
@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
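The helper normalizes the two shapes of VAE encoder output (a distribution or pre-computed latents). A minimal sketch of how it is used further down in __call__, assuming a loaded pipeline `pipe` and a preprocessed 3-channel `image_tensor` (both hypothetical here):

posterior = pipe.vae.encode(image_tensor)  # AutoencoderKLOutput exposing .latent_dist
generator = torch.Generator(device="cpu").manual_seed(0)  # fixed seed -> reproducible sample
latents = retrieve_latents(posterior, generator=generator) * pipe.vae.config.scaling_factor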
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
def preprocess(image):
    warnings.warn(
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")

    def _encode_prompt(
        self,
        prompt,
        device,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            **kwargs,
        )

        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])

        return prompt_embeds, pooled_prompt_embeds
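For callers still on the deprecated entry point, the shim above preserves the old output layout. A sketch, assuming a loaded pipeline `pipe` (hypothetical):

text_embeddings, text_pooler_out = pipe._encode_prompt(  # emits a deprecation warning
    "a castle", "cuda", True, negative_prompt="blurry"
)
# Both tensors come back CFG-concatenated as [negative; positive] along the
# batch dimension, exactly the layout the pre-#8947 pipeline returned.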
    def encode_prompt(
        self,
        prompt,
        device,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
    ):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix ...@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`). if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
text_inputs = self.tokenizer( elif prompt is not None and isinstance(prompt, list):
prompt, batch_size = len(prompt)
padding="max_length", else:
max_length=self.tokenizer.model_max_length, batch_size = prompt_embeds.shape[0]
truncation=True,
return_length=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_out = self.text_encoder(
text_input_ids.to(device),
output_hidden_states=True,
)
text_embeddings = text_encoder_out.hidden_states[-1]
text_pooler_out = text_encoder_out.pooler_output
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] if prompt_embeds is None or pooled_prompt_embeds is None:
uncond_input = self.tokenizer( text_inputs = self.tokenizer(
uncond_tokens, prompt,
padding="max_length", padding="max_length",
max_length=max_length, max_length=self.tokenizer.model_max_length,
truncation=True, truncation=True,
return_length=True, return_length=True,
return_tensors="pt", return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
uncond_encoder_out = self.text_encoder( untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
uncond_input.input_ids.to(device),
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_out = self.text_encoder(
text_input_ids.to(device),
output_hidden_states=True, output_hidden_states=True,
) )
prompt_embeds = text_encoder_out.hidden_states[-1]
pooled_prompt_embeds = text_encoder_out.pooler_output
uncond_embeddings = uncond_encoder_out.hidden_states[-1] # get unconditional embeddings for classifier free guidance
uncond_pooler_out = uncond_encoder_out.pooler_output if do_classifier_free_guidance:
if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_length=True,
return_tensors="pt",
)
uncond_encoder_out = self.text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# For classifier free guidance, we need to do two forward passes. negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
# Here we concatenate the unconditional and text embeddings into a single batch negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
return text_embeddings, text_pooler_out return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
    def check_inputs(
        self,
        prompt,
        image,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
    ):
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, np.ndarray)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix

        # verify batch size of prompt and image are same if image is a list or tensor
        if isinstance(image, (list, torch.Tensor)):
            if prompt is not None:
                if isinstance(prompt, str):
                    batch_size = 1
                else:
                    batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            if isinstance(image, list):
                image_batch_size = len(image)
            else:
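The new guards make misuse fail fast. A quick sketch of the embeds-only path tripping one of them, with a hypothetical `pipe`, `low_res_latents`, and an illustrative embedding shape:

try:
    pipe(prompt=None, image=low_res_latents, prompt_embeds=torch.randn(1, 77, 768))
except ValueError as err:
    print(err)  # asks for `pooled_prompt_embeds` generated by the same text encoder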
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
        """
        # 1. Check inputs
        self.check_inputs(
            prompt,
            image,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None:
            batch_size = 1 if isinstance(prompt, str) else len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
            prompt = [""] * batch_size

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt,
            device,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)
        image = image.to(dtype=prompt_embeds.dtype, device=device)
        if image.shape[1] == 3:
            # encode image if not in latent-space yet
            image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
        inv_noise_level = (noise_level**2 + 1) ** (-0.5)

        image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
        image_cond = image_cond.to(prompt_embeds.dtype)

        noise_level_embed = torch.cat(
            [
                torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
                torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
            ],
            dim=1,
        )

        timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)
        # 6. Prepare latent variables
        height, width = image.shape[2:]
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
            num_channels_latents,
            height * 2,  # 2x upscale
            width * 2,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
            noise_pred = self.unet(
                scaled_model_input,
                timestep,
                encoder_hidden_states=prompt_embeds,
                timestep_cond=timestep_condition,
            ).sample
@@ -178,6 +178,46 @@ class StableDiffusionLatentUpscalePipelineFastTests(
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)
    def test_stable_diffusion_latent_upscaler_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array(
            [0.43865365, 0.404124, 0.42618454, 0.44333526, 0.40564927, 0.43818694, 0.4411913, 0.43404633, 0.46392226]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_latent_upscaler_multiple_init_images(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionLatentUpscalePipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["prompt"] = [inputs["prompt"]] * 2
        inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
        image = sd_pipe(**inputs).images
        image_slice = image[-1, -3:, -3:, -1]

        assert image.shape == (2, 256, 256, 3)
        expected_slice = np.array(
            [0.38730142, 0.35695046, 0.40646142, 0.40967226, 0.3981609, 0.4195988, 0.4248805, 0.430259, 0.45694894]
        )
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
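The new embeds path itself isn't exercised by these two tests. A sketch of how it could be covered in the same style — a hypothetical test, not part of this commit, with shapes coming from the dummy components:

    def test_stable_diffusion_latent_upscaler_prompt_embeds(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionLatentUpscalePipeline(**components).to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        prompt = inputs.pop("prompt")  # replace the string prompt with embeddings
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = sd_pipe.encode_prompt(prompt, device, do_classifier_free_guidance=True)

        image = sd_pipe(
            **inputs,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        ).images

        assert image.shape == (1, 256, 256, 3)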
    def test_attention_slicing_forward_pass(self):
        super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)