"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "af28ae2d5ba0ef80d99fff7859ebea730e1cf3f8"
Unverified commit e6fd9ada authored by Sayak Paul, committed by GitHub

[I2vGenXL] clean up things (#6845)

* remove _to_tensor

* remove _to_tensor definition

* remove _collapse_frames_into_batch

* remove lora for not bloating the code.

* remove sample_size.

* simplify code a bit more

* ensure timesteps are always in tensor.
parent 493228a7
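To make the "ensure timesteps are always in tensor" item concrete, the sketch below mirrors the conversion this commit inlines into the UNet forward pass: a Python int, float, or 0-d tensor timestep is promoted to a 1-d tensor on the sample's device and then broadcast to the batch dimension. The helper name `_normalize_timestep` and the example shapes are illustrative only, not part of the commit.

import torch


def _normalize_timestep(timestep, device):
    # Illustrative stand-in for the conversion now written out inside I2VGenXLUNet.forward.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # Match the original dtype choice: 32-bit on MPS, 64-bit elsewhere.
        is_mps = device.type == "mps"
        if isinstance(timesteps, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
    elif len(timesteps.shape) == 0:
        # Promote a 0-d tensor to 1-d so it can be expanded to the batch size.
        timesteps = timesteps[None].to(device)
    return timesteps


sample = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, frames, height, width)
for t in (981, 981.0, torch.tensor(981)):
    ts = _normalize_timestep(t, sample.device).expand(sample.shape[0])
    print(ts.shape)  # torch.Size([2]) for every input type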
@@ -48,29 +48,6 @@ from .unet_3d_condition import UNet3DConditionOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _to_tensor(inputs, device):
-    if not torch.is_tensor(inputs):
-        # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = device.type == "mps"
-        if isinstance(inputs, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        inputs = torch.tensor([inputs], dtype=dtype, device=device)
-    elif len(inputs.shape) == 0:
-        inputs = inputs[None].to(device)
-    return inputs
-
-
-def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
-    batch_size, channels, num_frames, height, width = sample.shape
-    sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-    return sample
-
-
class I2VGenXLTransformerTemporalEncoder(nn.Module):
    def __init__(
        self,
@@ -174,8 +151,6 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    ):
        super().__init__()

-        self.sample_size = sample_size
-
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
@@ -543,7 +518,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
            forward_upsample_size = True

        # 1. time
-        timesteps = _to_tensor(timestep, sample.device)
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timesteps, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
@@ -572,7 +558,13 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
        context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)

-        image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
+        image_latents_for_context_embds = image_latents[:, :, :1, :]
+        image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
+            image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
+            image_latents_for_context_embds.shape[1],
+            image_latents_for_context_embds.shape[3],
+            image_latents_for_context_embds.shape[4],
+        )
        image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)

        _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -586,7 +578,12 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        context_emb = torch.cat([context_emb, image_emb], dim=1)
        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)

-        image_latents = _collapse_frames_into_batch(image_latents)
+        image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
+            image_latents.shape[0] * image_latents.shape[2],
+            image_latents.shape[1],
+            image_latents.shape[3],
+            image_latents.shape[4],
+        )
        image_latents = self.image_latents_proj_in(image_latents)
        image_latents = (
            image_latents[None, :]
......
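The inlined permute/reshape above reproduces exactly what the removed `_collapse_frames_into_batch` helper did: a video latent of shape (batch, channels, frames, height, width) is folded into an image-shaped batch of (batch * frames, channels, height, width). A minimal standalone check, with shapes chosen arbitrarily for illustration:

import torch

batch, channels, frames, height, width = 2, 4, 16, 32, 32
image_latents = torch.randn(batch, channels, frames, height, width)

# What the old helper _collapse_frames_into_batch(image_latents) computed.
collapsed_old = image_latents.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)

# The inlined form from the diff, using .shape instead of unpacked names.
collapsed_new = image_latents.permute(0, 2, 1, 3, 4).reshape(
    image_latents.shape[0] * image_latents.shape[2],
    image_latents.shape[1],
    image_latents.shape[3],
    image_latents.shape[4],
)

assert collapsed_new.shape == (batch * frames, channels, height, width)
assert torch.equal(collapsed_old, collapsed_new)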
@@ -22,18 +22,13 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin
from ...models import AutoencoderKL
-from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
from ...schedulers import DDIMScheduler
from ...utils import (
-    USE_PEFT_BACKEND,
    BaseOutput,
    logging,
    replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -207,7 +202,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
@@ -233,23 +227,10 @@ class I2VGenXLPipeline(DiffusionPipeline):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
-            lora_scale (`float`, *optional*):
-                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
-        # set lora scale so that monkey patched LoRA
-        # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-            # dynamically adjust the LoRA scale
-            if not USE_PEFT_BACKEND:
-                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            else:
-                scale_lora_layers(self.text_encoder, lora_scale)
-
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
@@ -380,10 +361,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
-
        return prompt_embeds, negative_prompt_embeds

    def _encode_image(self, image, device, num_videos_per_prompt):
@@ -706,9 +683,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
        self._guidance_scale = guidance_scale

        # 3.1 Encode input text prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
@@ -716,7 +690,6 @@ class I2VGenXLPipeline(DiffusionPipeline):
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
......
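On the pipeline side, only the text-encoder LoRA scaling path is removed (the `lora_scale` argument of `encode_prompt` and the `cross_attention_kwargs["scale"]` plumbing); ordinary generation is untouched. A minimal, hedged usage sketch after this cleanup, assuming the publicly documented `ali-vilab/i2vgen-xl` checkpoint; the image path is a placeholder:

import torch
from diffusers import I2VGenXLPipeline
from diffusers.utils import export_to_gif, load_image

# Load the documented I2VGen-XL checkpoint in half precision and offload to save VRAM.
pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
pipeline.enable_model_cpu_offload()

# Placeholder path: any conditioning image works here.
image = load_image("path/to/first_frame.png")

frames = pipeline(
    prompt="Papers were floating in the air on a table in the library",
    image=image,
    num_inference_steps=50,
    guidance_scale=9.0,
    generator=torch.manual_seed(0),
).frames[0]
export_to_gif(frames, "i2v.gif")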