Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c8651158
Unverified
Commit
c8651158
authored
Apr 30, 2025
by
Aryan
Committed by
GitHub
Apr 30, 2025
Browse files
`torch.compile` fullgraph compatibility for Hunyuan Video (#11457)
udpate
parent
60892c55
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
7 deletions
+5
-7
src/diffusers/models/transformers/transformer_hunyuan_video.py
...iffusers/models/transformers/transformer_hunyuan_video.py
+5
-7
No files found.
src/diffusers/models/transformers/transformer_hunyuan_video.py
View file @
c8651158
...
...
@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
latent_sequence_length
=
hidden_states
.
shape
[
1
]
condition_sequence_length
=
encoder_hidden_states
.
shape
[
1
]
sequence_length
=
latent_sequence_length
+
condition_sequence_length
attention_mask
=
torch
.
zero
s
(
attention_mask
=
torch
.
one
s
(
batch_size
,
sequence_length
,
device
=
hidden_states
.
device
,
dtype
=
torch
.
bool
)
# [B, N]
effective_condition_sequence_length
=
encoder_attention_mask
.
sum
(
dim
=
1
,
dtype
=
torch
.
int
)
# [B,]
effective_sequence_length
=
latent_sequence_length
+
effective_condition_sequence_length
for
i
in
range
(
batch_size
):
attention_mask
[
i
,
:
effective_sequence_length
[
i
]]
=
True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
1
)
indices
=
torch
.
arange
(
sequence_length
,
device
=
hidden_states
.
device
).
unsqueeze
(
0
)
# [1, N]
mask_indices
=
indices
>=
effective_sequence_length
.
unsqueeze
(
1
)
# [B, N]
attention_mask
=
attention_mask
.
masked_fill
(
mask_indices
,
False
)
attention_mask
=
attention_mask
.
unsqueeze
(
1
).
unsqueeze
(
1
)
# [B, 1, 1, N]
# 4. Transformer blocks
if
torch
.
is_grad_enabled
()
and
self
.
gradient_checkpointing
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment