Unverified Commit c888d6dc authored by tcmyxc's avatar tcmyxc Committed by GitHub
Browse files

Update vision_transformer.py (#5820)



the assert msg should be same
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 669d5655
......@@ -79,7 +79,7 @@ class EncoderBlock(nn.Module):
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (seq_length, batch_size, hidden_dim) got {input.shape}")
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x = self.dropout(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment