renzhc / diffusers_dcu · Commits · 8b4f8ba7

Unverified commit 8b4f8ba7, authored Mar 12, 2025 by hlky, committed by GitHub on Mar 12, 2025
Use `output_size` in `repeat_interleave` (#11030)
parent 54280464

Showing 12 changed files with 56 additions and 27 deletions:
- src/diffusers/models/attention_processor.py (+10 −4)
- src/diffusers/models/autoencoders/autoencoder_dc.py (+4 −2)
- src/diffusers/models/autoencoders/autoencoder_kl_allegro.py (+1 −1)
- src/diffusers/models/autoencoders/autoencoder_kl_mochi.py (+3 −1)
- src/diffusers/models/controlnets/controlnet_sparsectrl.py (+1 −1)
- src/diffusers/models/embeddings.py (+5 −3)
- src/diffusers/models/transformers/latte_transformer_3d.py (+12 −6)
- src/diffusers/models/transformers/prior_transformer.py (+5 −1)
- src/diffusers/models/unets/unet_3d_condition.py (+4 −2)
- src/diffusers/models/unets/unet_i2vgen_xl.py (+2 −2)
- src/diffusers/models/unets/unet_motion_model.py (+5 −2)
- src/diffusers/models/unets/unet_spatio_temporal_condition.py (+4 −2)
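For context on the change itself: `torch.repeat_interleave` accepts an optional keyword-only `output_size` argument giving the total length of the result along the repeated dimension. Per the PyTorch docs, supplying it lets the framework skip inferring that size (which requires a stream synchronization when `repeats` is a tensor), and it can help shape-tracing paths that prefer a stated size. A minimal sketch of the pattern this commit applies throughout (the tensor `x` here is illustrative, not from the diff):

```python
import torch

x = torch.randn(2, 4)

# Before: the output length along dim 0 is inferred at call time.
a = x.repeat_interleave(3, dim=0)
# After: the output length is stated up front; the values are identical.
b = x.repeat_interleave(3, dim=0, output_size=x.shape[0] * 3)

assert torch.equal(a, b)
print(b.shape)  # torch.Size([6, 4])
```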
src/diffusers/models/attention_processor.py

```diff
@@ -741,10 +741,14 @@ class Attention(nn.Module):
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )

         return attention_mask
@@ -3704,8 +3708,10 @@ class StableAudioAttnProcessor2_0:
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )

         if attn.norm_q is not None:
             query = attn.norm_q(query)
```
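The StableAudio hunk above expands grouped-query (GQA/MQA) key/value heads to match the query head count before attention. A small sketch of that expansion, with made-up shapes rather than the model's real ones:

```python
import torch

batch, seq, head_dim = 2, 16, 64
q_heads, kv_heads = 8, 2               # GQA: 4 query heads share each KV head
heads_per_kv_head = q_heads // kv_heads

key = torch.randn(batch, kv_heads, seq, head_dim)
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)

assert key.shape[1] == q_heads  # each KV head now appears once per query head in its group
```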
src/diffusers/models/autoencoders/autoencoder_dc.py

```diff
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
         x = F.pixel_shuffle(x, self.factor)

         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ class Decoder(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
```
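The `DCUpBlock2d` shortcut repeats channels so that `pixel_shuffle` lands back on the original channel count while upsampling the spatial dimensions. An illustrative sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

factor = 2
repeats = factor**2                      # pixel_shuffle folds factor**2 channels into one
hidden_states = torch.randn(1, 16, 8, 8)

y = hidden_states.repeat_interleave(repeats, dim=1, output_size=hidden_states.shape[1] * repeats)
y = F.pixel_shuffle(y, factor)

assert y.shape == (1, 16, 16, 16)        # channels preserved, H and W upsampled by `factor`
```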
src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

```diff
@@ -103,7 +103,7 @@ class AllegroTemporalConvLayer(nn.Module):
         if self.down_sample:
             identity = hidden_states[:, :, ::2]
         elif self.up_sample:
-            identity = hidden_states.repeat_interleave(2, dim=2)
+            identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2)
         else:
             identity = hidden_states
```
src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

```diff
@@ -426,7 +426,9 @@ class FourierFeatures(nn.Module):
         w = w.repeat(num_channels)[None, :, None, None, None]  # [1, num_channels * num_freqs, 1, 1, 1]

         # Interleaved repeat of input channels to match w
-        h = inputs.repeat_interleave(num_freqs, dim=1)  # [B, C * num_freqs, T, H, W]
+        h = inputs.repeat_interleave(
+            num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs
+        )  # [B, C * num_freqs, T, H, W]
         # Scale channels by frequency.
         h = w * h
```
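In `FourierFeatures`, each input channel is repeated once per frequency so it lines up element-wise with `w`, which tiles the frequency vector per channel. A sketch with illustrative sizes:

```python
import torch

B, C, T, H, W = 1, 3, 2, 4, 4
num_freqs = 5
inputs = torch.randn(B, C, T, H, W)

h = inputs.repeat_interleave(num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs)

# Channel c now occupies rows c * num_freqs .. (c + 1) * num_freqs - 1 of dim 1,
# i.e. one copy of the frequency vector per input channel.
assert h.shape == (B, C * num_freqs, T, H, W)
```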
src/diffusers/models/controlnets/controlnet_sparsectrl.py

```diff
@@ -687,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)

         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
```
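This is the first of several video models in the commit that broadcast per-batch embeddings to per-frame ones, matching a latent flattened from `[B, C, F, H, W]` to `[B * F, C, H, W]`. A sketch of the indexing, with illustrative sizes:

```python
import torch

batch, frames, channels = 2, 8, 320
emb = torch.randn(batch, channels)

emb = emb.repeat_interleave(frames, dim=0, output_size=emb.shape[0] * frames)

# Rows are ordered [b0, b0, ..., b1, b1, ...], so frame f of batch b
# lines up with row b * frames + f of the flattened video latent.
assert emb.shape == (batch * frames, channels)
```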
src/diffusers/models/embeddings.py

```diff
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
-    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(
+        temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
+    )  # [T, H*W, D // 4 * 3]

     pos_embed_temporal = pos_embed_temporal[:, None, :]
     pos_embed_temporal = pos_embed_temporal.repeat_interleave(
@@ -1154,8 +1156,8 @@ def get_1d_rotary_pos_embed(
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
         # stable audio, allegro
```
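The `get_1d_rotary_pos_embed` hunk duplicates each frequency's cos/sin for the (even, odd) channel pair it rotates, going from `[S, D/2]` to `[S, D]`. A sketch with illustrative sizes (the `inv_freq` construction mirrors the usual RoPE setup, not this file's exact code):

```python
import torch

S, D = 4, 8
pos = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
freqs = torch.outer(pos, inv_freq)  # [S, D/2]

freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2)  # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2)  # [S, D]

assert freqs_cos.shape == (S, D)
assert torch.equal(freqs_cos[:, 0], freqs_cos[:, 1])  # each frequency appears twice, interleaved
```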
src/diffusers/models/transformers/latte_transformer_3d.py

```diff
@@ -227,13 +227,17 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])

         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])

         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
@@ -299,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
```
src/diffusers/models/transformers/prior_transformer.py

```diff
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )

         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
```
src/diffusers/models/unets/unet_3d_condition.py

```diff
@@ -638,8 +638,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         t_emb = t_emb.to(dtype=self.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )

         # 2. pre-process
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
```
src/diffusers/models/unets/unet_i2vgen_xl.py

```diff
@@ -592,7 +592,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # 3. time + FPS embeddings.
         emb = t_emb + fps_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

         # 4. context embeddings.
         # The context embeddings consist of both text embeddings from the input prompt
@@ -620,7 +620,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         image_emb = self.context_embedding(image_embeddings)
         image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim)
         context_emb = torch.cat([context_emb, image_emb], dim=1)
-        context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
+        context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames)

         image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
             image_latents.shape[0] * image_latents.shape[2],
```
src/diffusers/models/unets/unet_motion_model.py

```diff
@@ -2059,7 +2059,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             aug_emb = self.add_embedding(add_embeds)

         emb = emb if aug_emb is None else emb + aug_emb
-        emb = emb.repeat_interleave(repeats=num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
@@ -2068,7 +2068,10 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             )
             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
-            image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
+            image_embeds = [
+                image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames)
+                for image_embed in image_embeds
+            ]
             encoder_hidden_states = (encoder_hidden_states, image_embeds)

         # 2. pre-process
```
src/diffusers/models/unets/unet_spatio_temporal_condition.py

```diff
@@ -431,9 +431,11 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         sample = sample.flatten(0, 1)
         # Repeat the embeddings num_video_frames times
         # emb: [batch, channels] -> [batch * frames, channels]
-        emb = emb.repeat_interleave(num_frames, dim=0)
+        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
         # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+        encoder_hidden_states = encoder_hidden_states.repeat_interleave(
+            num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
+        )

         # 2. pre-process
         sample = self.conv_in(sample)
```