Unverified Commit 8b4f8ba7 authored by hlky, committed by GitHub

Use `output_size` in `repeat_interleave` (#11030)

parent 54280464
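
The change is mechanical and repeated across every hunk below: each existing `repeat_interleave` call gains an `output_size` keyword whose value is simply the current length of the repeated dimension times the repeat count, so the returned values are unchanged. Per the PyTorch docs, supplying `output_size` lets `torch.repeat_interleave` skip deriving the output length itself (and, when `repeats` is a tensor, the stream synchronization that calculation requires). A minimal sketch of the pattern, with a tensor and shapes invented purely for illustration:

```python
import torch

x = torch.randn(2, 3)

# Baseline: repeat each column 4 times along dim=1 -> shape [2, 12].
y = x.repeat_interleave(4, dim=1)

# Same call, but the output length along dim=1 (3 * 4 = 12) is declared up front
# via `output_size`, so PyTorch does not need to infer it from `repeats`.
z = x.repeat_interleave(4, dim=1, output_size=x.shape[1] * 4)

assert z.shape == (2, 12)
assert torch.equal(y, z)  # identical values; `output_size` is only a size hint
```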
@@ -741,10 +741,14 @@ class Attention(nn.Module):
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )
         return attention_mask
@@ -3704,8 +3708,10 @@ class StableAudioAttnProcessor2_0:
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )
         if attn.norm_q is not None:
             query = attn.norm_q(query)
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
         x = F.pixel_shuffle(x, self.factor)
         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ class Decoder(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
         if self.down_sample:
             identity = hidden_states[:, :, ::2]
         elif self.up_sample:
-            identity = hidden_states.repeat_interleave(2, dim=2)
+            identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
         else:
             identity = hidden_states
@@ -426,7 +426,9 @@ class FourierFeatures(nn.Module):
         w = w.repeat(num_channels)[None, :, None, None, None]  # [1, num_channels * num_freqs, 1, 1, 1]
         # Interleaved repeat of input channels to match w
-        h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+        h = inputs.repeat_interleave(
+            num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+        )  # [B, C * num_freqs, T, H, W]
         # Scale channels by frequency.
         h = w * h
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         t_emb = t_emb.to(dtype=sample.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]
     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
@@ -227,13 +227,17 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )
         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
@@ -638,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
@@ -592,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)
         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
@@ -2059,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
             aug_emb = self.add_embedding(add_embeds)
         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         # 2. pre-process
@@ -431,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )
         # 2. pre-process
         sample = self.conv_in(sample)