Unverified Commit cd1b8d7c authored by dg845, committed by GitHub

[WIP] Refactor UniDiffuser Pipeline and Tests (#4948)



* Add VAE slicing and tiling methods.

* Switch to using VaeImageProcessor for preprocessing and postprocessing of images.

* Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor).

* Remove the postprocess() function because we're using a VaeImageProcessor instead.

* Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead.

* Refactor generating text from text latents into a decode_text_latents method.

* Add enable_full_determinism() to UniDiffuser tests.

* make style

* Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests.

* Remove enable_model_cpu_offload since it is now part of DiffusionPipeline.

* Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash.

* Update UniDiffuser conversion script.

* Make safe_serialization configurable in UniDiffuser conversion script.

* Rename image_processor to clip_image_processor in UniDiffuser tests.

* Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests.

* Add initial test for compiling the UniDiffuser model (not tested yet).

* Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline.

* Turn off standard classifier-free guidance for now.

* make style

* make fix-copies

* apply suggestions from review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent db91e710
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

-        new_item = new_item.replace("q.weight", "query.weight")
+        new_item = new_item.replace("q.weight", "to_q.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.bias", "to_q.bias")

-        new_item = new_item.replace("k.weight", "key.weight")
+        new_item = new_item.replace("k.weight", "to_k.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.bias", "to_k.bias")

-        new_item = new_item.replace("v.weight", "value.weight")
+        new_item = new_item.replace("v.weight", "to_v.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.bias", "to_v.bias")

-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
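Note: the hunk above switches the conversion to the newer diffusers attention naming (to_q/to_k/to_v/to_out.0). A minimal, self-contained sketch of how the same substitution rules act on a few illustrative checkpoint keys (the key names here are made up for the example):

# Illustrative only: shows the effect of the renamed replacement table.
replacements = [
    ("q.weight", "to_q.weight"), ("q.bias", "to_q.bias"),
    ("k.weight", "to_k.weight"), ("k.bias", "to_k.bias"),
    ("v.weight", "to_v.weight"), ("v.bias", "to_v.bias"),
    ("proj_out.weight", "to_out.0.weight"), ("proj_out.bias", "to_out.0.bias"),
]

old_keys = ["mid.attn_1.q.weight", "mid.attn_1.k.bias", "mid.attn_1.proj_out.weight"]
for old_key in old_keys:
    new_key = old_key
    for old, new in replacements:
        new_key = new_key.replace(old, new)
    print(old_key, "->", new_key)
# mid.attn_1.q.weight -> mid.attn_1.to_q.weight
# mid.attn_1.k.bias -> mid.attn_1.to_k.bias
# mid.attn_1.proj_out.weight -> mid.attn_1.to_out.0.weight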
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    return mapping


+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
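The copied conv_attn_to_linear helper is needed because the original VAE stores attention projections as 1x1 convolutions, while the diffusers modules use nn.Linear. A rough sketch of the shape change it performs, with illustrative tensor sizes:

import torch

# Illustrative shapes: a 1x1 Conv2d projection stores weights as
# (out_channels, in_channels, 1, 1); the equivalent nn.Linear expects
# (out_channels, in_channels), so the trailing spatial dims are dropped.
conv_weight = torch.randn(512, 512, 1, 1)   # e.g. an old-style "query.weight"
linear_weight = conv_weight[:, :, 0, 0]     # what conv_attn_to_linear produces
print(tuple(conv_weight.shape), "->", tuple(linear_weight.shape))  # (512, 512, 1, 1) -> (512, 512)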
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
# config.num_head_channels => num_head_channels
def assign_to_checkpoint(
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
-    attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
-    checkpoint.
+    attention layers, and takes into account additional replacements that may arise.
+
+    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
            new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


-# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
-def conv_attn_to_linear(checkpoint):
-    keys = list(checkpoint.keys())
-    attn_keys = ["query.weight", "key.weight", "value.weight"]
-    for key in keys:
-        if ".".join(key.split(".")[-2:]) in attn_keys:
-            if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0, 0]
-        elif "proj_attn.weight" in key:
-            if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0]


def create_vae_diffusers_config(config_type):
    # Hardcoded for now
    if args.config_type == "test":
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
    return text_decoder_config


-# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
+# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
    """
    Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ if __name__ == "__main__":
        type=int,
        help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
    )
+    parser.add_argument(
+        "--safe_serialization",
+        action="store_true",
+        help="Whether to use safetensors/safe serialization when saving the pipeline.",
+    )

    args = parser.parse_args()
@@ -766,11 +776,11 @@ if __name__ == "__main__":
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
-        image_processor=image_processor,
+        clip_image_processor=image_processor,
        clip_tokenizer=clip_tokenizer,
        text_decoder=text_decoder,
        text_tokenizer=text_tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
-    pipeline.save_pretrained(args.pipeline_output_path)
+    pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
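With the new flag, the converted pipeline can be written out as safetensors. A small, hedged sketch of what the same save call looks like outside the script (the output directory name is a placeholder; the model id comes from the slow tests):

from diffusers import UniDiffuserPipeline

# Re-save an existing pipeline with safetensors weights; the conversion script
# does the equivalent via `--safe_serialization`, i.e.
#   pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe.save_pretrained("./unidiffuser-v1-safetensors", safe_serialization=True)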
@@ -13,9 +13,12 @@ from transformers import (
    GPT2Tokenizer,
)

+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
+from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging
+from ...utils import deprecate, logging
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -26,30 +29,6 @@ from .modeling_uvit import UniDiffuserModel
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
-def preprocess(image):
-    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
-    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
-    if isinstance(image, torch.Tensor):
-        return image
-    elif isinstance(image, PIL.Image.Image):
-        image = [image]
-
-    if isinstance(image[0], PIL.Image.Image):
-        w, h = image[0].size
-        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
-
-        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
-        image = np.concatenate(image, axis=0)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image.transpose(0, 3, 1, 2)
-        image = 2.0 * image - 1.0
-        image = torch.from_numpy(image)
-    elif isinstance(image[0], torch.Tensor):
-        image = torch.cat(image, dim=0)
-    return image


# New BaseOutput child class for joint image-text output
@dataclass
class ImageTextPipelineOutput(BaseOutput):
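The deprecated module-level preprocess() is removed; the pipeline now relies on VaeImageProcessor, which covers the same PIL-to-normalized-tensor conversion. A minimal sketch of that replacement, assuming an RGB PIL input (the random image is illustrative):

import numpy as np
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)

# Any RGB image; here a random 512x512 one for illustration.
pil_image = PIL.Image.fromarray(np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8))

# Equivalent of the removed `preprocess()`: resize to a multiple of the VAE scale
# factor and scale pixel values into [-1, 1] as a (B, C, H, W) tensor.
image_tensor = image_processor.preprocess(pil_image)
print(image_tensor.shape, image_tensor.min().item(), image_tensor.max().item())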
@@ -111,7 +90,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        image_encoder: CLIPVisionModelWithProjection,
-        image_processor: CLIPImageProcessor,
+        clip_image_processor: CLIPImageProcessor,
        clip_tokenizer: CLIPTokenizer,
        text_decoder: UniDiffuserTextDecoder,
        text_tokenizer: GPT2Tokenizer,
@@ -130,7 +109,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            vae=vae,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
-            image_processor=image_processor,
+            clip_image_processor=clip_image_processor,
            clip_tokenizer=clip_tokenizer,
            text_decoder=text_decoder,
            text_tokenizer=text_tokenizer,
@@ -139,6 +118,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.num_channels_latents = vae.config.latent_channels
        self.text_encoder_seq_len = text_encoder.config.max_position_embeddings
@@ -155,43 +135,38 @@ class UniDiffuserPipeline(DiffusionPipeline):
        # TODO: handle safety checking?
        self.safety_checker = None

-    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
-    # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
-    def enable_model_cpu_offload(self, gpu_id=0):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-        """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        device = torch.device(f"cuda:{gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        hook = None
-        for cpu_offloaded_model in [
-            self.text_encoder.text_model,
-            self.image_encoder,
-            self.unet,
-            self.vae,
-            self.text_decoder.encode_prefix,
-            self.text_decoder.decode_prefix,
-            self.text_decoder,
-        ]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
-
-        # We'll offload the last model manually.
-        self.final_offload_hook = hook
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
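The new methods are thin pass-throughs to the underlying AutoencoderKL. A usage sketch (the model id comes from the slow tests; prompt, step count, and device handling are illustrative):

import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
pipe = pipe.to("cuda")

# Both calls simply forward to the AutoencoderKL, trading a little speed for memory.
pipe.enable_vae_slicing()   # decode the batch slice by slice
pipe.enable_vae_tiling()    # decode/encode large images tile by tile

result = pipe(prompt="an astronaut riding a horse", num_inference_steps=20)
image = result.images[0]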
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
@@ -370,8 +345,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        )
        return batch_size, multiplier

-    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
-    # self.tokenizer => self.clip_tokenizer
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
@@ -381,6 +355,41 @@ class UniDiffuserPipeline(DiffusionPipeline):
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            **kwargs,
+        )
+
+        # concatenate for backwards comp
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
@@ -396,8 +405,8 @@ class UniDiffuserPipeline(DiffusionPipeline):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
-                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
@@ -405,7 +414,20 @@ class UniDiffuserPipeline(DiffusionPipeline):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
+            lora_scale (`float`, *optional*):
+                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
@@ -414,6 +436,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer)
+
            text_inputs = self.clip_tokenizer(
                prompt,
                padding="max_length",
@@ -440,13 +466,31 @@ class UniDiffuserPipeline(DiffusionPipeline):
            else:
                attention_mask = None

-            prompt_embeds = self.text_encoder(
-                text_input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            prompt_embeds = prompt_embeds[0]
+            if clip_skip is None:
+                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = prompt_embeds[0]
+            else:
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                )
+                # Access the `hidden_states` first, that contains a tuple of
+                # all the hidden states from the encoder layers. Then index into
+                # the tuple to access the hidden states from the desired layer.
+                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+                # We also need to apply the final LayerNorm here to not mess with the
+                # representations. The `last_hidden_states` that we typically use for
+                # obtaining the final prompt representations passes through the LayerNorm
+                # layer.
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -458,7 +502,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
@@ -474,6 +518,10 @@ class UniDiffuserPipeline(DiffusionPipeline):
            else:
                uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer)
+
            max_length = prompt_embeds.shape[1]
            uncond_input = self.clip_tokenizer(
                uncond_tokens,
@@ -498,17 +546,12 @@ class UniDiffuserPipeline(DiffusionPipeline):
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-
-        return prompt_embeds
+        return prompt_embeds, negative_prompt_embeds
    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
    # Add num_prompts_per_image argument, sample from autoencoder moment distribution
@@ -587,7 +630,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

-        preprocessed_image = self.image_processor.preprocess(
+        preprocessed_image = self.clip_image_processor.preprocess(
            image,
            return_tensors="pt",
        )
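After the rename there are two distinct processors on the pipeline: self.clip_image_processor (a transformers CLIPImageProcessor) prepares pixel_values for the CLIP image encoder, while self.image_processor (the diffusers VaeImageProcessor) prepares tensors for the VAE. A short sketch of the CLIP side, with an illustrative input image:

import numpy as np
import PIL.Image
from transformers import CLIPImageProcessor

clip_image_processor = CLIPImageProcessor()  # the pipeline's `self.clip_image_processor`

pil_image = PIL.Image.fromarray(np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8))

# CLIP-style preprocessing: resize/crop to the CLIP input resolution and normalize
# with CLIP's mean/std, returning `pixel_values` for the image encoder.
clip_inputs = clip_image_processor.preprocess(pil_image, return_tensors="pt")
print(clip_inputs["pixel_values"].shape)  # typically (1, 3, 224, 224)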
@@ -628,17 +671,6 @@ class UniDiffuserPipeline(DiffusionPipeline):
        return image_latents

-    # Note that the CLIP latents are not decoded for image generation.
-    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
-    # Rename: decode_latents -> decode_image_latents
-    def decode_image_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents, return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image

    def prepare_text_latents(
        self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
    ):
@@ -720,6 +752,17 @@ class UniDiffuserPipeline(DiffusionPipeline):
        latents = latents * self.scheduler.init_noise_sigma
        return latents

+    def decode_text_latents(self, text_latents, device):
+        output_token_list, seq_lengths = self.text_decoder.generate_captions(
+            text_latents, self.text_tokenizer.eos_token_id, device=device
+        )
+        output_list = output_token_list.cpu().numpy()
+        generated_text = [
+            self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+            for output, length in zip(output_list, seq_lengths)
+        ]
+        return generated_text
+
    def _split(self, x, height, width):
        r"""
        Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
@@ -1181,7 +1224,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        # Note that this differs from the formulation in the unidiffusers paper!
-        # do_classifier_free_guidance = guidance_scale > 1.0
+        do_classifier_free_guidance = guidance_scale > 1.0

        # check if scheduler is in sigmas space
        # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
@@ -1194,15 +1237,18 @@ class UniDiffuserPipeline(DiffusionPipeline):
        if mode in ["text2img"]:
            # 3.1. Encode input prompt, if available
            assert prompt is not None or prompt_embeds is not None
-            prompt_embeds = self._encode_prompt(
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt=prompt,
                device=device,
                num_images_per_prompt=multiplier,
-                do_classifier_free_guidance=False,  # don't support standard classifier-free guidance for now
+                do_classifier_free_guidance=do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
            )
+            # if do_classifier_free_guidance:
+            #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            # 3.2. Prepare text latent variables, if input not available
            prompt_embeds = self.prepare_text_latents(
@@ -1224,7 +1270,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
            # 4.1. Encode images, if available
            assert image is not None, "`img2text` requires a conditioning image"
            # Encode image using VAE
-            image_vae = preprocess(image)
+            image_vae = self.image_processor.preprocess(image)
            height, width = image_vae.shape[-2:]
            image_vae_latents = self.encode_image_vae_latents(
                image=image_vae,
@@ -1324,48 +1370,42 @@ class UniDiffuserPipeline(DiffusionPipeline):
                        callback(i, t, latents)

        # 9. Post-processing
-        gen_image = None
-        gen_text = None
+        image = None
+        text = None
        if mode == "joint":
            image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width)

-            # Map latent VAE image back to pixel space
-            gen_image = self.decode_image_latents(image_vae_latents)
-
-            # Generate text using the text decoder
-            output_token_list, seq_lengths = self.text_decoder.generate_captions(
-                text_latents, self.text_tokenizer.eos_token_id, device=device
-            )
-            output_list = output_token_list.cpu().numpy()
-            gen_text = [
-                self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
-                for output, length in zip(output_list, seq_lengths)
-            ]
+            if not output_type == "latent":
+                # Map latent VAE image back to pixel space
+                image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            else:
+                image = image_vae_latents
+
+            text = self.decode_text_latents(text_latents, device)
        elif mode in ["text2img", "img"]:
            image_vae_latents, image_clip_latents = self._split(latents, height, width)
-            gen_image = self.decode_image_latents(image_vae_latents)
+
+            if not output_type == "latent":
+                # Map latent VAE image back to pixel space
+                image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            else:
+                image = image_vae_latents
        elif mode in ["img2text", "text"]:
            text_latents = latents
-            output_token_list, seq_lengths = self.text_decoder.generate_captions(
-                text_latents, self.text_tokenizer.eos_token_id, device=device
-            )
-            output_list = output_token_list.cpu().numpy()
-            gen_text = [
-                self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True)
-                for output, length in zip(output_list, seq_lengths)
-            ]
+            text = self.decode_text_latents(text_latents, device)

        self.maybe_free_model_hooks()

-        # 10. Convert to PIL
-        if output_type == "pil" and gen_image is not None:
-            gen_image = self.numpy_to_pil(gen_image)
+        # 10. Postprocess the image, if necessary
+        if image is not None:
+            do_denormalize = [True] * image.shape[0]
+            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
-            return (gen_image, gen_text)
+            return (image, text)

-        return ImageTextPipelineOutput(images=gen_image, text=gen_text)
+        return ImageTextPipelineOutput(images=image, text=text)
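Because postprocessing now goes through VaeImageProcessor.postprocess, the pipeline accepts the usual output_type values, including "latent". A sketch of requesting raw latents and converting them manually, mirroring the decode/postprocess calls above (model id from the slow tests; prompt and step count are illustrative):

import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1").to("cuda")

# With output_type="latent" the pipeline returns undecoded VAE latents.
out = pipe(prompt="a painting of a squirrel", num_inference_steps=20, output_type="latent")
latents = out.images

# Manual decode + postprocess, mirroring what the pipeline does for "pil"/"np"/"pt".
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])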
import gc
import random
+import traceback
import unittest

import numpy as np
@@ -20,17 +21,70 @@ from diffusers import (
    UniDiffuserPipeline,
    UniDiffuserTextDecoder,
)
-from diffusers.utils.testing_utils import floats_tensor, load_image, nightly, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    load_image,
+    nightly,
+    require_torch_2,
+    require_torch_gpu,
+    run_test_in_subprocess,
+    torch_device,
+)
from diffusers.utils.torch_utils import randn_tensor

-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+    IMAGE_TO_IMAGE_IMAGE_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin


+enable_full_determinism()
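enable_full_determinism() pins PyTorch to deterministic behaviour so the hard-coded image slices in these tests are reproducible across runs. A rough approximation of the kind of settings such a helper needs (this is a sketch of the idea, not the helper's exact body):

import os
import torch

# Approximation of a "full determinism" setup for reproducible test outputs;
# the diffusers helper bundles settings along these lines.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # needed for deterministic cuBLAS GEMMs
torch.use_deterministic_algorithms(True)         # error out on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False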
+# Will be run via run_test_in_subprocess
+def _test_unidiffuser_compile(in_queue, out_queue, timeout):
+    error = None
+    try:
+        inputs = in_queue.get(timeout=timeout)
+        torch_device = inputs.pop("torch_device")
+        seed = inputs.pop("seed")
+        inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
+
+        pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
+        # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe = pipe.to(torch_device)
+
+        pipe.unet.to(memory_format=torch.channels_last)
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
+        assert np.abs(image_slice - expected_slice).max() < 1e-1
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()


-class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class UniDiffuserPipelineFastTests(
+    PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
    pipeline_class = UniDiffuserPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+    # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
+    image_latents_params = frozenset(["vae_latents"])

    def get_dummy_components(self):
        unet = UniDiffuserModel.from_pretrained(
@@ -64,7 +118,7 @@ class UniDiffuserPipelineFastTests(
            subfolder="image_encoder",
        )
        # From the Stable Diffusion Image Variation pipeline tests
-        image_processor = CLIPImageProcessor(crop_size=32, size=32)
+        clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
        # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")

        text_tokenizer = GPT2Tokenizer.from_pretrained(
@@ -80,7 +134,7 @@ class UniDiffuserPipelineFastTests(
            "vae": vae,
            "text_encoder": text_encoder,
            "image_encoder": image_encoder,
-            "image_processor": image_processor,
+            "clip_image_processor": clip_image_processor,
            "clip_tokenizer": clip_tokenizer,
            "text_decoder": text_decoder,
            "text_tokenizer": text_tokenizer,
@@ -619,6 +673,19 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
        expected_text_prefix = "An astronaut"
        assert text[0][: len(expected_text_prefix)] == expected_text_prefix

+    @unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
+    @require_torch_2
+    def test_unidiffuser_compile(self, seed=0):
+        inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
+        # Delete prompt and image for joint inference.
+        del inputs["prompt"]
+        del inputs["image"]
+        # Can't pickle a Generator object
+        del inputs["generator"]
+        inputs["torch_device"] = torch_device
+        inputs["seed"] = seed
+        run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
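The compile test is dispatched through run_test_in_subprocess so compilation state stays isolated from the rest of the suite. Outside the test harness, the same recipe from _test_unidiffuser_compile applies directly; a sketch (prompt, step count, and dtype handling are illustrative):

import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16).to("cuda")

# Same recipe as _test_unidiffuser_compile: channels_last + torch.compile on the UNet.
pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

# The first call triggers compilation and is slow; later calls reuse the compiled graph.
image = pipe(prompt="an astronaut riding a horse", num_inference_steps=20).images[0]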
@nightly
@require_torch_gpu
...