Commit 05679329 (unverified), authored by YiYi Xu, committed by GitHub
Browse files

[Hi Dream] follow-up (#11296)

* add
parent 29d2afbf
...@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config ...@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin from ...loaders import PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps from ..embeddings import TimestepEmbedding, Timesteps
...@@ -686,46 +686,108 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -686,46 +686,108 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
x = torch.cat(x_arr, dim=0) x = torch.cat(x_arr, dim=0)
return x return x
def patchify(self, hidden_states):
    """Split a batch of 2D latents into a sequence of flattened patches.

    Each `patch_size x patch_size` spatial patch becomes one token whose features are
    ordered position-major, channel-last (`patch_size * patch_size * channels` values).
    Square inputs (`height == width`) are returned at their natural length. Non-square
    inputs are zero-padded along the sequence dimension to `self.max_seq`, and a mask
    marking the real tokens is returned alongside them.

    Args:
        hidden_states: Tensor of shape `(batch_size, channels, height, width)`.
            `height` and `width` must both be multiples of `self.config.patch_size`.

    Returns:
        A tuple `(hidden_states, hidden_states_masks, img_sizes, img_ids)`:

        - `hidden_states`: `(batch_size, seq_len, patch_size * patch_size * channels)`,
          where `seq_len` is `patch_height * patch_width` for square inputs and
          `self.max_seq` otherwise.
        - `hidden_states_masks`: `None` for square inputs; otherwise a
          `(batch_size, self.max_seq)` tensor in the input dtype holding 1.0 for real
          tokens and 0.0 for padding.
        - `img_sizes`: int64 tensor `(batch_size, 2)` with `[patch_height, patch_width]`
          repeated for every sample.
        - `img_ids`: `(batch_size, seq_len, 3)` tensor; component 0 stays zero,
          component 1 is the patch row index and component 2 the patch column index.
          Padding rows are all zeros.

    Raises:
        ValueError: If `height` or `width` is not divisible by the patch size, or if a
            non-square input yields more patches than `self.max_seq`.
    """
    batch_size, channels, height, width = hidden_states.shape
    patch_size = self.config.patch_size
    device = hidden_states.device
    dtype = hidden_states.dtype

    # Fail early with a readable message instead of an opaque reshape error below.
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(
            f"`hidden_states` height ({height}) and width ({width}) must be divisible by "
            f"`patch_size` ({patch_size})."
        )

    patch_height, patch_width = height // patch_size, width // patch_size
    num_patches = patch_height * patch_width
    token_dim = patch_size * patch_size * channels

    # Only non-square latents are padded to the fixed `max_seq` length.
    needs_padding = height != width
    if needs_padding and num_patches > self.max_seq:
        raise ValueError(
            f"Non-square `hidden_states` produce {num_patches} patches, which exceeds "
            f"`max_seq` ({self.max_seq})."
        )

    # Per-sample patch-grid size; identical for every element of the batch.
    img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device)
    img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)

    # Positional ids: (0, row, col) for each patch, flattened row-major.
    img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(patch_height, device=device)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(patch_width, device=device)[None, :]
    img_ids = img_ids.reshape(num_patches, -1)

    # (B, C, H, W) -> (B, pH, pW, p, p, C) -> (B, pH * pW, p * p * C)
    hidden_states = hidden_states.reshape(
        batch_size, channels, patch_height, patch_size, patch_width, patch_size
    )
    hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
    hidden_states = hidden_states.reshape(batch_size, num_patches, token_dim)

    hidden_states_masks = None
    if needs_padding:
        # Mask flags the real (unpadded) tokens.
        hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
        hidden_states_masks[:, :num_patches] = 1.0

        img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
        img_ids_pad[:num_patches, :] = img_ids
        img_ids = img_ids_pad

        padded = torch.zeros((batch_size, self.max_seq, token_dim), dtype=dtype, device=device)
        padded[:, :num_patches] = hidden_states
        hidden_states = padded

    img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)

    return hidden_states, hidden_states_masks, img_sizes, img_ids
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None, timesteps: torch.LongTensor = None,
encoder_hidden_states: torch.Tensor = None, encoder_hidden_states_t5: torch.Tensor = None,
encoder_hidden_states_llama3: torch.Tensor = None,
pooled_embeds: torch.Tensor = None, pooled_embeds: torch.Tensor = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
img_ids: Optional[torch.Tensor] = None, img_ids: Optional[torch.Tensor] = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
hidden_states_masks: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None, attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True, return_dict: bool = True,
**kwargs,
): ):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
encoder_hidden_states_t5 = encoder_hidden_states[0]
encoder_hidden_states_llama3 = encoder_hidden_states[1]
if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
deprecation_message = (
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
)
deprecate("img_ids", "0.34.0", deprecation_message)
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
elif hidden_states_masks is not None and hidden_states.ndim != 3:
raise ValueError(
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
if attention_kwargs is not None: if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy() attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0) lora_scale = attention_kwargs.pop("scale", 1.0)
...@@ -745,42 +807,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -745,42 +807,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype hidden_states_type = hidden_states.dtype
if hidden_states.shape[-2] != hidden_states.shape[-1]: # Patchify the input
B, C, H, W = hidden_states.shape if hidden_states_masks is None:
patch_size = self.config.patch_size hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
pH, pW = H // patch_size, W // patch_size
out = torch.zeros( # Embed the hidden states
(B, C, self.max_seq, patch_size * patch_size), hidden_states = self.x_embedder(hidden_states)
dtype=hidden_states.dtype,
device=hidden_states.device,
)
hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
out[:, :, 0 : pH * pW] = hidden_states
hidden_states = out
# 0. time # 0. time
timesteps = self.t_embedder(timesteps, hidden_states_type) timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds) p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder temb = timesteps + p_embedder
hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
if hidden_states_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = (
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)
hidden_states = self.x_embedder(hidden_states)
T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
if self.caption_projection is not None: if self.caption_projection is not None:
new_encoder_hidden_states = [] new_encoder_hidden_states = []
...@@ -789,9 +828,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -789,9 +828,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state) new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states) encoder_hidden_states.append(encoder_hidden_states_t5)
txt_ids = torch.zeros( txt_ids = torch.zeros(
batch_size, batch_size,
......
...@@ -43,7 +43,7 @@ enable_full_determinism() ...@@ -43,7 +43,7 @@ enable_full_determinism()
class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HiDreamImagePipeline pipeline_class = HiDreamImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment