Unverified commit 7116fd24, authored by Zehuan Huang and committed by GitHub

Support passing kwargs to the CogVideoX custom attention processor (#10456)

* Support passing kwargs to the CogVideoX custom attention processor

* Remove args in the CogVideoX attention processor

* Remove unused kwargs
parent 553b1384
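
In effect, the change threads an `attention_kwargs` dict from `CogVideoXTransformer3DModel.forward` through each `CogVideoXBlock` and expands it into the `self.attn1(...)` call, so extra keyword arguments reach a custom attention processor's `__call__`. A minimal sketch of such a processor follows; the class name and the `attn_gain` keyword are hypothetical illustrations, not part of diffusers. Because the stock processor no longer swallows `*args, **kwargs`, any extra keyword must be declared explicitly in the processor's signature.

    import torch
    from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0


    class ScaledCogVideoXAttnProcessor(CogVideoXAttnProcessor2_0):
        # Hypothetical example: run the stock CogVideoX attention, then
        # rescale the visual stream by `attn_gain`, a custom keyword
        # delivered through `attention_kwargs`.
        def __call__(
            self,
            attn,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            attention_mask=None,
            image_rotary_emb=None,
            attn_gain: float = 1.0,  # extra kwarg routed in via attention_kwargs
        ):
            hidden_states, encoder_hidden_states = super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            return attn_gain * hidden_states, encoder_hidden_states
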
@@ -120,8 +120,10 @@ class CogVideoXBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ class CogVideoXBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -498,6 +501,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -506,6 +510,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )
 
             if not self.config.use_rotary_positional_embeddings:
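
Note that the gradient-checkpointing branch passes `attention_kwargs` positionally, since `create_custom_forward` forwards `*inputs` straight to the block, while the eager branch passes it by keyword. At the call site, a usage sketch under stated assumptions: that `CogVideoXPipeline` forwards its `attention_kwargs` argument to the transformer, that the model exposes the usual `set_attn_processor` helper, and that `attn_gain` matches the keyword declared by the sketch processor above.

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b", torch_dtype=torch.float16
    )
    pipe.to("cuda")
    # Install the custom processor on every attention layer.
    pipe.transformer.set_attn_processor(ScaledCogVideoXAttnProcessor())

    # Keys must match parameters declared by the processor's __call__;
    # keys the processor does not declare are filtered out with a warning.
    video = pipe(
        prompt="a panda playing a guitar in a bamboo forest",
        num_inference_steps=50,
        attention_kwargs={"attn_gain": 1.05},
    ).frames[0]
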