"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "2a02d7f5396ef6b1d0ad40535238092e6a4bc1a9"
Unverified commit 7116fd24, authored by Zehuan Huang, committed by GitHub

Support passing kwargs to the CogVideoX custom attention processor (#10456)

* Support passing kwargs to the CogVideoX custom attention processor

* Remove args in the CogVideoX attn processor

* Remove unused kwargs
parent 553b1384
@@ -120,8 +120,10 @@ class CogVideoXBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ class CogVideoXBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )

         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -498,6 +501,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -506,6 +510,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )

         if not self.config.use_rotary_positional_embeddings:
...
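For reference, below is a minimal sketch (not part of this commit) of a custom attention processor that consumes an extra keyword delivered through the new attention_kwargs plumbing. The class name ScaledCogVideoXAttnProcessor and the attn_scale keyword are hypothetical, and the base-class call signature is assumed to match the stock CogVideoXAttnProcessor2_0 in diffusers.

from typing import Optional

import torch
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0


class ScaledCogVideoXAttnProcessor(CogVideoXAttnProcessor2_0):
    # Hypothetical processor: scales the attention output of the visual stream
    # by an extra `attn_scale` keyword received via `attention_kwargs`.
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        attn_scale: float = 1.0,  # hypothetical extra kwarg, not part of this commit
    ):
        # Run the stock CogVideoX attention, then rescale the visual hidden states.
        hidden_states, encoder_hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        return attn_scale * hidden_states, encoder_hidden_states

After installing such a processor (e.g. via the transformer's set_attn_processor, if available in your diffusers version), the extra keyword would be supplied at call time as transformer(..., attention_kwargs={"attn_scale": 0.5}); each CogVideoXBlock then unpacks the dict into self.attn1(...) as shown in the diff above.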