Unverified Commit 8b4f8ba7 authored by hlky, committed by GitHub

Use `output_size` in `repeat_interleave` (#11030)

parent 54280464
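
Note: `torch.repeat_interleave` accepts an optional `output_size` keyword giving the total length of the repeated dimension; per the PyTorch docs, supplying it avoids the stream synchronization otherwise needed to compute the output shape. A minimal sketch of the before/after pattern used throughout this commit (tensor names and sizes are illustrative, not taken from the diff):

import torch

attention_mask = torch.ones(2, 77)  # [batch, seq_len], made-up sizes
head_size = 8

# Before: the length of dim=0 in the result is inferred by the op.
expanded = attention_mask.repeat_interleave(head_size, dim=0)

# After: the caller passes the known output length explicitly.
expanded_with_size = attention_mask.repeat_interleave(
    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
)

# Identical results; only how the output shape is obtained differs.
assert torch.equal(expanded, expanded_with_size)
assert expanded_with_size.shape == (2 * head_size, 77)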
...@@ -741,10 +741,14 @@ class Attention(nn.Module):
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )
         return attention_mask
...@@ -3704,8 +3708,10 @@ class StableAudioAttnProcessor2_0:
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )
         if attn.norm_q is not None:
             query = attn.norm_q(query)
...
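
Note: the `StableAudioAttnProcessor2_0` hunk applies the change to grouped-query/multi-query attention, where each key/value head is repeated `heads_per_kv_head` times along the head axis to match the number of query heads. A standalone sketch with made-up sizes, assuming the [batch, heads, seq_len, head_dim] layout used at this point in the processor:

import torch

batch, seq_len, head_dim = 2, 16, 64
num_query_heads, num_kv_heads = 8, 2
heads_per_kv_head = num_query_heads // num_kv_heads  # 4

key = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Repeat each KV head so every query head has a matching key head.
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
assert key.shape == (batch, num_query_heads, seq_len, head_dim)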
...@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
         x = F.pixel_shuffle(x, self.factor)
         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
...@@ -361,7 +361,9 @@ class Decoder(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
...
...@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
         if self.down_sample:
             identity = hidden_states[:, :, ::2]
         elif self.up_sample:
-            identity = hidden_states.repeat_interleave(2, dim=2)
+            identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
         else:
             identity = hidden_states
...
...@@ -426,7 +426,9 @@ class FourierFeatures(nn.Module):
         w = w.repeat(num_channels)[None, :, None, None, None]  # [1, num_channels * num_freqs, 1, 1, 1]
         # Interleaved repeat of input channels to match w
-        h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+        h = inputs.repeat_interleave(
+            num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+        )  # [B, C * num_freqs, T, H, W]
         # Scale channels by frequency.
         h = w * h
...
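
Note: the `FourierFeatures` change keeps the existing alignment trick intact: `w.repeat(num_channels)` tiles the frequency vector once per channel, while `repeat_interleave(num_freqs, dim=1)` repeats each input channel once per frequency, so the element-wise product pairs every channel with every frequency. A small sketch with invented sizes:

import torch

num_freqs, num_channels = 3, 2
inputs = torch.randn(1, num_channels, 4, 8, 8)  # [B, C, T, H, W]
w = torch.tensor([1.0, 2.0, 4.0])               # per-frequency scales

# Tile frequencies per channel: [f0, f1, f2, f0, f1, f2]
w = w.repeat(num_channels)[None, :, None, None, None]  # [1, C * num_freqs, 1, 1, 1]

# Repeat each channel per frequency: [c0, c0, c0, c1, c1, c1]
h = inputs.repeat_interleave(num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs)

# Every (channel, frequency) pair is aligned for the element-wise scale.
assert (w * h).shape == (1, num_channels * num_freqs, 4, 8, 8)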
...@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
...
...@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]
     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
...@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
...
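
Note: in `get_1d_rotary_pos_embed`, the real-valued branch duplicates each of the D/2 frequency columns so the cos/sin tables match the interleaved rotary layout used by Flux, HunyuanDiT, and CogVideoX. A sketch of that duplication with illustrative sizes (the frequency formula below is a simplified stand-in for the function's theta-based schedule):

import torch

seq_len, dim = 4, 8
pos = torch.arange(seq_len, dtype=torch.float32)
freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # [D/2]
freqs = torch.outer(pos, freqs)  # [S, D/2]

# Duplicate every frequency column so each row reads [f0, f0, f1, f1, ...].
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2)  # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2)  # [S, D]
assert freqs_cos.shape == (seq_len, dim)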
...@@ -227,13 +227,17 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
...@@ -299,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
...
...@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )
         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
...
...@@ -638,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
...
...@@ -592,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
...@@ -620,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
...
...@@ -2059,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
             aug_emb = self.add_embedding(add_embeds)
         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
...@@ -2068,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         # 2. pre-process
...
...@@ -431,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
         # 2. pre-process
         sample = self.conv_in(sample)
...
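
Note: the video UNet hunks (SparseControlNet, UNet3D, I2VGen-XL, UNetMotion, UNetSpatioTemporal) all share one pattern: per-video conditioning tensors of shape [batch, ...] are repeated along dim=0 so they line up with a sample whose frame axis has been folded into the batch. A condensed sketch with illustrative names and sizes:

import torch

batch, num_frames, channels = 2, 16, 320
emb = torch.randn(batch, channels)                   # one time embedding per video
encoder_hidden_states = torch.randn(batch, 77, 768)  # one text embedding per video

# Repeat per-video embeddings so rows i*num_frames .. (i+1)*num_frames - 1
# all carry video i's conditioning, matching a [batch * num_frames, ...] sample.
emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
    num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
)

assert emb.shape == (batch * num_frames, channels)
assert encoder_hidden_states.shape == (batch * num_frames, 77, 768)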