Unverified Commit c8651158 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

`torch.compile` fullgraph compatibility for Hunyuan Video (#11457)

udpate
parent 60892c55
...@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, ...@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
latent_sequence_length = hidden_states.shape[1] latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1] condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros( attention_mask = torch.ones(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N] ) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
for i in range(batch_size): mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
attention_mask[i, : effective_sequence_length[i]] = True attention_mask = attention_mask.masked_fill(mask_indices, False)
# [B, 1, 1, N], for broadcasting across attention heads attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
# 4. Transformer blocks # 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing: if torch.is_grad_enabled() and self.gradient_checkpointing:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment