"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "636feba552ab1f9a1ecdd3748e1245fa45b76d0d"
Commit 5b4f79d9, authored by Vladislav Sovrasov, committed by GitHub

Don't use named args in MHA calls to allow applying pytorch forward hooks to VIT (#6956)

parent d710f3d1
@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
     def forward(self, input: torch.Tensor):
         torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
         x = self.ln_1(input)
-        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
+        x, _ = self.self_attention(x, x, x, need_weights=False)
         x = self.dropout(x)
         x = x + input
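
A minimal sketch of why the call style matters, assuming PyTorch's default forward-hook behaviour: a hook registered without `with_kwargs=True` receives only the positional arguments of the call, so inputs passed by keyword never reach it. The hook name `record_inputs`, the list `seen`, and the tensor sizes are illustrative and not part of this commit.

```python
import torch
from torch import nn

# Stand-alone attention layer, similar in kind to the one used in EncoderBlock.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

seen = []  # number of positional inputs the hook received on each call


def record_inputs(module, args, output):
    # A default forward hook gets `args` as a tuple of positional inputs only.
    seen.append(len(args))


handle = mha.register_forward_hook(record_inputs)
x = torch.randn(1, 4, 8)  # (batch_size, seq_length, hidden_dim)

# Old call style: everything is a keyword argument, so the hook sees no inputs.
mha(query=x, key=x, value=x, need_weights=False)

# New call style: query, key and value are positional and visible to the hook.
mha(x, x, x, need_weights=False)

handle.remove()
print(seen)
```

With the keyword-only call the hook's `args` tuple is empty, while the positional call exposes all three tensors, which is what hook-based tools such as feature extractors or profilers typically rely on.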