Unverified Commit 8d9674ed authored by Tri Dao's avatar Tri Dao Committed by GitHub
Browse files

Merge pull request #102 from Lamikins/main

fixed cross attention typeerror
parents 93383bd5 aec35fd6
...@@ -341,6 +341,7 @@ class MHA(nn.Module): ...@@ -341,6 +341,7 @@ class MHA(nn.Module):
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
groups=3 * embed_dim) groups=3 * embed_dim)
else: else:
inner_attn_cls = inner_cross_attn_cls
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs) self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
if not self.return_residual: if not self.return_residual:
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment