"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a8f563dbf8520020054aa01f5ae169999775fd19"
Unverified Commit 97e0ef4d authored by Aryan, committed by GitHub
Browse files

Hidream refactoring follow ups (#11299)



* HiDream Image

* update

* -einops

* py3.8

* fix -einops

* mixins, offload_seq, option_components

* docs

* Apply style fixes

* trigger tests

* Apply suggestions from code review
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* joint_attention_kwargs -> attention_kwargs, fixes

* fast tests

* -_init_weights

* style tests

* move reshape logic

* update slice 😴

* supports_dduf

* 🤷🏻‍♂️

* Update src/diffusers/models/transformers/transformer_hidream_image.py
Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* address review comments

* update tests

* doc updates

* update

* Update src/diffusers/models/transformers/transformer_hidream_image.py

* Apply style fixes

---------
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent ed41db85
...@@ -604,8 +604,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -604,8 +604,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
): ):
super().__init__() super().__init__()
self.out_channels = out_channels or in_channels self.out_channels = out_channels or in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.llama_layers = llama_layers
self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim) self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim) self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
...@@ -621,13 +620,13 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -621,13 +620,13 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
HiDreamBlock( HiDreamBlock(
HiDreamImageTransformerBlock( HiDreamImageTransformerBlock(
dim=self.inner_dim, dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads, num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim, attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts, num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts, num_activated_experts=num_activated_experts,
) )
) )
for _ in range(self.config.num_layers) for _ in range(num_layers)
] ]
) )
...@@ -636,42 +635,26 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -636,42 +635,26 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
HiDreamBlock( HiDreamBlock(
HiDreamImageSingleTransformerBlock( HiDreamImageSingleTransformerBlock(
dim=self.inner_dim, dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads, num_attention_heads=num_attention_heads,
attention_head_dim=self.config.attention_head_dim, attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts, num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts, num_activated_experts=num_activated_experts,
) )
) )
for _ in range(self.config.num_single_layers) for _ in range(num_single_layers)
] ]
) )
self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels) self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
caption_channels = [ caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
caption_channels[1],
] * (num_layers + num_single_layers) + [
caption_channels[0],
]
caption_projection = [] caption_projection = []
for caption_channel in caption_channels: for caption_channel in caption_channels:
caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim)) caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
self.caption_projection = nn.ModuleList(caption_projection) self.caption_projection = nn.ModuleList(caption_projection)
self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
def expand_timesteps(self, timesteps, batch_size, device): self.gradient_checkpointing = False
if not torch.is_tensor(timesteps):
is_mps = device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(batch_size)
return timesteps
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]: def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training: if is_training:
...@@ -773,7 +756,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -773,7 +756,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states = out hidden_states = out
# 0. time # 0. time
timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
timesteps = self.t_embedder(timesteps, hidden_states_type) timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds) p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder temb = timesteps + p_embedder
...@@ -793,7 +775,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -793,7 +775,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
T5_encoder_hidden_states = encoder_hidden_states[0] T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1] encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
if self.caption_projection is not None: if self.caption_projection is not None:
new_encoder_hidden_states = [] new_encoder_hidden_states = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment