Unverified Commit 5d5ecb45 authored by Yiwen Song's avatar Yiwen Song Committed by GitHub
Browse files

[ViT] Refactor forward function (#5172)

* refactor forward function

* reemove n from return
parent 058f4bd7
...@@ -202,7 +202,7 @@ class VisionTransformer(nn.Module): ...@@ -202,7 +202,7 @@ class VisionTransformer(nn.Module):
nn.init.zeros_(self.heads.head.weight) nn.init.zeros_(self.heads.head.weight)
nn.init.zeros_(self.heads.head.bias) nn.init.zeros_(self.heads.head.bias)
def forward(self, x: torch.Tensor): def _process_input(self, x: torch.Tensor) -> torch.Tensor:
n, c, h, w = x.shape n, c, h, w = x.shape
p = self.patch_size p = self.patch_size
torch._assert(h == self.image_size, "Wrong image height!") torch._assert(h == self.image_size, "Wrong image height!")
...@@ -221,7 +221,14 @@ class VisionTransformer(nn.Module): ...@@ -221,7 +221,14 @@ class VisionTransformer(nn.Module):
# embedding dimension # embedding dimension
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
# Expand the class token to the full batch. return x
def forward(self, x: torch.Tensor):
# Reshaping and permuting the input tensor
x = self._process_input(x)
n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = self.class_token.expand(n, -1, -1) batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1) x = torch.cat([batch_class_token, x], dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment