Unverified Commit 4a4afd5e authored by hlky's avatar hlky Committed by GitHub
Browse files

Fix batch > 1 in HunyuanVideo (#10548)

parent aa79d7da
......@@ -727,7 +727,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment