Unverified Commit b28ab302 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Unclip] Make sure text_embeddings & image_embeddings can directly be passed...

[Unclip] Make sure text_embeddings & image_embeddings can directly be passed to enable interpolation tasks. (#1858)

* [Unclip] Make sure latents can be reused

* allow one to directly pass embeddings

* up

* make unclip for text work

* finish allowing to pass embeddings

* correct more

* make style
parent 29b2c93c
...@@ -13,12 +13,13 @@ ...@@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import List, Optional, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput from ...pipelines import DiffusionPipeline, ImagePipelineOutput
...@@ -117,31 +118,44 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -117,31 +118,44 @@ class UnCLIPPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma latents = latents * scheduler.init_noise_sigma
return latents return latents
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): def _encode_prompt(
batch_size = len(prompt) if isinstance(prompt, list) else 1 self,
prompt,
# get prompt text embeddings device,
text_inputs = self.tokenizer( num_images_per_prompt,
prompt, do_classifier_free_guidance,
padding="max_length", text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
max_length=self.tokenizer.model_max_length, text_attention_mask: Optional[torch.Tensor] = None,
return_tensors="pt", ):
) if text_model_output is None:
text_input_ids = text_inputs.input_ids batch_size = len(prompt) if isinstance(prompt, list) else 1
text_mask = text_inputs.attention_mask.bool().to(device) # get prompt text embeddings
text_inputs = self.tokenizer(
if text_input_ids.shape[-1] > self.tokenizer.model_max_length: prompt,
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) padding="max_length",
logger.warning( max_length=self.tokenizer.model_max_length,
"The following part of your input was truncated because CLIP can only handle sequences up to" return_tensors="pt",
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device)) text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
...@@ -150,11 +164,10 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -150,11 +164,10 @@ class UnCLIPPipeline(DiffusionPipeline):
if do_classifier_free_guidance: if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
max_length=max_length, max_length=self.tokenizer.model_max_length,
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
...@@ -235,7 +248,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -235,7 +248,7 @@ class UnCLIPPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25, prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25,
...@@ -244,6 +257,8 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -244,6 +257,8 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_latents: Optional[torch.FloatTensor] = None, prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None, decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None,
text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
text_attention_mask: Optional[torch.Tensor] = None,
prior_guidance_scale: float = 4.0, prior_guidance_scale: float = 4.0,
decoder_guidance_scale: float = 8.0, decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
...@@ -254,7 +269,8 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -254,7 +269,8 @@ class UnCLIPPipeline(DiffusionPipeline):
Args: Args:
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation. The prompt or prompts to guide the image generation. This can only be left undefined if
`text_model_output` and `text_attention_mask` is passed.
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
prior_num_inference_steps (`int`, *optional*, defaults to 25): prior_num_inference_steps (`int`, *optional*, defaults to 25):
...@@ -287,18 +303,29 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -287,18 +303,29 @@ class UnCLIPPipeline(DiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
text_model_output (`CLIPTextModelOutput`, *optional*):
Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
can be passed for tasks like text embedding interpolations. Make sure to also pass
`text_attention_mask` in this case. `prompt` can the be left to `None`.
text_attention_mask (`torch.Tensor`, *optional*):
Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
masks are necessary when passing `text_model_output`.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
""" """
if isinstance(prompt, str): if prompt is not None:
batch_size = 1 if isinstance(prompt, str):
elif isinstance(prompt, list): batch_size = 1
batch_size = len(prompt) elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = text_model_output[0].shape[0]
device = self._execution_device device = self._execution_device
batch_size = batch_size * num_images_per_prompt batch_size = batch_size * num_images_per_prompt
...@@ -306,7 +333,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -306,7 +333,7 @@ class UnCLIPPipeline(DiffusionPipeline):
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
) )
# prior # prior
...@@ -315,6 +342,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -315,6 +342,7 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_timesteps_tensor = self.prior_scheduler.timesteps prior_timesteps_tensor = self.prior_scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim embedding_dim = self.prior.config.embedding_dim
prior_latents = self.prepare_latents( prior_latents = self.prepare_latents(
(batch_size, embedding_dim), (batch_size, embedding_dim),
text_embeddings.dtype, text_embeddings.dtype,
...@@ -378,6 +406,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -378,6 +406,7 @@ class UnCLIPPipeline(DiffusionPipeline):
num_channels_latents = self.decoder.in_channels num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size height = self.decoder.sample_size
width = self.decoder.sample_size width = self.decoder.sample_size
decoder_latents = self.prepare_latents( decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width), (batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype, text_encoder_hidden_states.dtype,
...@@ -430,6 +459,7 @@ class UnCLIPPipeline(DiffusionPipeline): ...@@ -430,6 +459,7 @@ class UnCLIPPipeline(DiffusionPipeline):
channels = self.super_res_first.in_channels // 2 channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size height = self.super_res_first.sample_size
width = self.super_res_first.sample_size width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents( super_res_latents = self.prepare_latents(
(batch_size, channels, height, width), (batch_size, channels, height, width),
image_small.dtype, image_small.dtype,
......
...@@ -126,7 +126,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -126,7 +126,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
latents = latents * scheduler.init_noise_sigma latents = latents * scheduler.init_noise_sigma
return latents return latents
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1 batch_size = len(prompt) if isinstance(prompt, list) else 1
...@@ -139,15 +138,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -139,15 +138,6 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
) )
text_input_ids = text_inputs.input_ids text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device) text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(device)) text_encoder_output = self.text_encoder(text_input_ids.to(device))
text_embeddings = text_encoder_output.text_embeds text_embeddings = text_encoder_output.text_embeds
...@@ -199,14 +189,15 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -199,14 +189,15 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
return text_embeddings, text_encoder_hidden_states, text_mask return text_embeddings, text_encoder_hidden_states, text_mask
def _encode_image(self, image, device, num_images_per_prompt): def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
dtype = next(self.image_encoder.parameters()).dtype dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor): if image_embeddings is None:
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
...@@ -258,13 +249,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -258,13 +249,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
decoder_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7, super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None, generator: Optional[torch.Generator] = None,
decoder_latents: Optional[torch.FloatTensor] = None, decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.Tensor] = None,
decoder_guidance_scale: float = 8.0, decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
...@@ -277,7 +269,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -277,7 +269,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of configuration of
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`. `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed.
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
decoder_num_inference_steps (`int`, *optional*, defaults to 25): decoder_num_inference_steps (`int`, *optional*, defaults to 25):
...@@ -299,18 +291,24 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -299,18 +291,24 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
image_embeddings (`torch.Tensor`, *optional*):
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
can be passed for tasks like image interpolations. `image` can the be left to `None`.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
""" """
if isinstance(image, PIL.Image.Image): if image is not None:
batch_size = 1 if isinstance(image, PIL.Image.Image):
elif isinstance(image, list): batch_size = 1
batch_size = len(image) elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
else: else:
batch_size = image.shape[0] batch_size = image_embeddings.shape[0]
prompt = [""] * batch_size prompt = [""] * batch_size
...@@ -324,10 +322,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -324,10 +322,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
prompt, device, num_images_per_prompt, do_classifier_free_guidance prompt, device, num_images_per_prompt, do_classifier_free_guidance
) )
image_embeddings = self._encode_image(image, device, num_images_per_prompt) image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)
# decoder # decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings, image_embeddings=image_embeddings,
text_embeddings=text_embeddings, text_embeddings=text_embeddings,
...@@ -343,14 +340,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -343,14 +340,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
num_channels_latents = self.decoder.in_channels num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size height = self.decoder.sample_size
width = self.decoder.sample_size width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width), if decoder_latents is None:
text_encoder_hidden_states.dtype, decoder_latents = self.prepare_latents(
device, (batch_size, num_channels_latents, height, width),
generator, text_encoder_hidden_states.dtype,
decoder_latents, device,
self.decoder_scheduler, generator,
) decoder_latents,
self.decoder_scheduler,
)
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
...@@ -395,14 +394,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline): ...@@ -395,14 +394,16 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
channels = self.super_res_first.in_channels // 2 channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size height = self.super_res_first.sample_size
width = self.super_res_first.sample_size width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width), if super_res_latents is None:
image_small.dtype, super_res_latents = self.prepare_latents(
device, (batch_size, channels, height, width),
generator, image_small.dtype,
super_res_latents, device,
self.super_res_scheduler, generator,
) super_res_latents,
self.super_res_scheduler,
)
interpolate_antialias = {} interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters: if "antialias" in inspect.signature(F.interpolate).parameters:
......
...@@ -248,6 +248,120 @@ class UnCLIPPipelineFastTests(unittest.TestCase): ...@@ -248,6 +248,120 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_passed_text_embed(self):
device = torch.device("cpu")
class DummyScheduler:
init_noise_sigma = 1
prior = self.dummy_prior
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
super_res_first = self.dummy_super_res_first
super_res_last = self.dummy_super_res_last
prior_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="sample",
num_train_timesteps=1000,
clip_sample_range=5.0,
)
decoder_scheduler = UnCLIPScheduler(
variance_type="learned_range",
prediction_type="epsilon",
num_train_timesteps=1000,
)
super_res_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="epsilon",
num_train_timesteps=1000,
)
pipe = UnCLIPPipeline(
prior=prior,
decoder=decoder,
text_proj=text_proj,
text_encoder=text_encoder,
tokenizer=tokenizer,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
pipe = pipe.to(device)
generator = torch.Generator(device=device).manual_seed(0)
dtype = prior.dtype
batch_size = 1
shape = (batch_size, prior.config.embedding_dim)
prior_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (batch_size, decoder.in_channels, decoder.sample_size, decoder.sample_size)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
super_res_first.in_channels // 2,
super_res_first.sample_size,
super_res_first.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
pipe.set_progress_bar_config(disable=None)
prompt = "this is a prompt example"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
prior_latents=prior_latents,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
output_type="np",
)
image = output.images
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
)
text_model_output = text_encoder(text_inputs.input_ids)
text_attention_mask = text_inputs.attention_mask
generator = torch.Generator(device=device).manual_seed(0)
image_from_text = pipe(
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
prior_latents=prior_latents,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
text_model_output=text_model_output,
text_attention_mask=text_attention_mask,
output_type="np",
)[0]
# make sure passing text embeddings manually is identical
assert np.abs(image - image_from_text).max() < 1e-4
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -407,6 +407,55 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase): ...@@ -407,6 +407,55 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_passed_image_embed(self):
device = torch.device("cpu")
seed = 0
class DummyScheduler:
init_noise_sigma = 1
pipe = self.get_pipeline(device)
generator = torch.Generator(device=device).manual_seed(0)
dtype = pipe.decoder.dtype
batch_size = 1
shape = (batch_size, pipe.decoder.in_channels, pipe.decoder.sample_size, pipe.decoder.sample_size)
decoder_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
shape = (
batch_size,
pipe.super_res_first.in_channels // 2,
pipe.super_res_first.sample_size,
pipe.super_res_first.sample_size,
)
super_res_latents = pipe.prepare_latents(
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
pipeline_inputs = self.get_pipeline_inputs(device, seed)
img_out_1 = pipe(
**pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
).images
pipeline_inputs = self.get_pipeline_inputs(device, seed)
# Don't pass image, instead pass embedding
image = pipeline_inputs.pop("image")
image_embeddings = pipe.image_encoder(image).image_embeds
img_out_2 = pipe(
**pipeline_inputs,
decoder_latents=decoder_latents,
super_res_latents=super_res_latents,
image_embeddings=image_embeddings,
).images
# make sure passing text embeddings manually is identical
assert np.abs(img_out_1 - img_out_2).max() < 1e-4
@slow @slow
@require_torch_gpu @require_torch_gpu
...@@ -426,11 +475,10 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase): ...@@ -426,11 +475,10 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
"/unclip/karlo_v1_alpha_cat_variation_fp16.npy" "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
) )
pipeline = UnCLIPImageVariationPipeline.from_pretrained( pipeline = UnCLIPImageVariationPipeline.from_pretrained("fusing/karlo-image-variations-diffusers")
"fusing/karlo-image-variations-diffusers", torch_dtype=torch.float16
)
pipeline = pipeline.to(torch_device) pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None) pipeline.set_progress_bar_config(disable=None)
pipeline.enable_sequential_cpu_offload()
generator = torch.Generator(device=torch_device).manual_seed(0) generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline( output = pipeline(
...@@ -442,7 +490,5 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase): ...@@ -442,7 +490,5 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
image = output.images[0] image = output.images[0]
np.save("./karlo_v1_alpha_cat_variation_fp16.npy", image)
assert image.shape == (256, 256, 3) assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2 assert np.abs(expected_image - image).max() < 5e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment