Commit 5b4f79d9, authored by Vladislav Sovrasov, committed by GitHub

Don't use named args in MHA calls to allow applying pytorch forward hooks to VIT (#6956)

parent d710f3d1
@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
     def forward(self, input: torch.Tensor):
         torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
         x = self.ln_1(input)
-        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
+        x, _ = self.self_attention(x, x, x, need_weights=False)
         x = self.dropout(x)
         x = x + input
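
The motivation in the commit title can be illustrated with a small sketch. A forward hook registered with the default options receives only the positional arguments of the call, so a keyword-only call to the attention module leaves the hook with an empty input tuple. The hook name `record_inputs`, the list `captured`, and the tensor sizes below are illustrative assumptions, not part of the commit.

```python
import torch
from torch import nn

# Small attention module; the sizes are arbitrary and chosen only for the demo.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

captured = []  # number of positional inputs seen by the hook on each call


def record_inputs(module, args, output):
    # A default forward hook is passed only the positional arguments.
    captured.append(len(args))


handle = mha.register_forward_hook(record_inputs)
x = torch.randn(1, 4, 8)  # (batch_size, seq_length, hidden_dim)

mha(query=x, key=x, value=x, need_weights=False)  # call style before the commit
mha(x, x, x, need_weights=False)                  # call style after the commit
handle.remove()

print(captured)  # [0, 3]
```

With the keyword call the hook sees no inputs, while the positional call exposes query, key, and value. Recent PyTorch releases also accept `with_kwargs=True` in `register_forward_hook`, which passes keyword arguments to the hook as well, but passing the tensors positionally keeps hooks written for the default signature working.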