"vscode:/vscode.git/clone" did not exist on "435d37ce5acb01f446ebb6fdd274915bc2f27bc8"
Unverified Commit 11d22e0e authored by Samuel Tesfai, committed by GitHub

Cross attention module to Wan Attention (#12058)



* Cross attention module to Wan Attention

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 9a38fab5
@@ -180,6 +180,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
         added_kv_proj_dim: Optional[int] = None,
         cross_attention_dim_head: Optional[int] = None,
         processor=None,
+        is_cross_attention=None,
     ):
         super().__init__()
@@ -207,6 +208,8 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
+        self.is_cross_attention = cross_attention_dim_head is not None
+
         self.set_processor(processor)
 
     def fuse_projections(self):
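
A minimal, self-contained sketch of the pattern this commit introduces: the module derives its is_cross_attention flag from whether cross_attention_dim_head was supplied, while still accepting is_cross_attention in the constructor (e.g. for config round-tripping). ToyWanAttention below is a hypothetical stand-in, not the real diffusers class; the parameter names mirror the diff context, and dim, heads, dim_head, and eps defaults are assumptions.

from typing import Optional

import torch


class ToyWanAttention(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        added_kv_proj_dim: Optional[int] = None,
        cross_attention_dim_head: Optional[int] = None,
        is_cross_attention=None,  # accepted but unused, as in the diff
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        if added_kv_proj_dim is not None:
            # Extra K/V projections for image/text conditioning, as in the diff.
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
        # The flag is derived from cross_attention_dim_head, matching the commit.
        self.is_cross_attention = cross_attention_dim_head is not None


self_attn = ToyWanAttention(dim=128)
cross_attn = ToyWanAttention(dim=128, cross_attention_dim_head=16)
print(self_attn.is_cross_attention)   # False
print(cross_attn.is_cross_attention)  # True

Deriving the flag from cross_attention_dim_head keeps it consistent with how the projections are actually built, so a stale or contradictory is_cross_attention argument cannot desynchronize the module's behavior.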