Unverified Commit 76c74b37 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

VAN: update modules names (#16201)

* done

* done
parent 99e2982f
...@@ -207,10 +207,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_ ...@@ -207,10 +207,10 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
} }
names_to_original_checkpoints = { names_to_original_checkpoints = {
"van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny/resolve/main/van_tiny_754.pth.tar", "van-tiny": "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar",
"van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small/resolve/main/van_small_811.pth.tar", "van-small": "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar",
"van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base/resolve/main/van_base_828.pth.tar", "van-base": "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar",
"van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large/resolve/main/van_large_839.pth.tar", "van-large": "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar",
} }
if model_name: if model_name:
......
...@@ -154,8 +154,10 @@ class VanOverlappingPatchEmbedder(nn.Sequential): ...@@ -154,8 +154,10 @@ class VanOverlappingPatchEmbedder(nn.Sequential):
def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4): def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4):
super().__init__() super().__init__()
self.conv = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2) self.convolution = nn.Conv2d(
self.norm = nn.BatchNorm2d(hidden_size) in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
)
self.normalization = nn.BatchNorm2d(hidden_size)
class VanMlpLayer(nn.Sequential): class VanMlpLayer(nn.Sequential):
...@@ -173,12 +175,12 @@ class VanMlpLayer(nn.Sequential): ...@@ -173,12 +175,12 @@ class VanMlpLayer(nn.Sequential):
dropout_rate: float = 0.5, dropout_rate: float = 0.5,
): ):
super().__init__() super().__init__()
self.fc1 = nn.Conv2d(in_channels, hidden_size, kernel_size=1) self.in_dense = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size) self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
self.activation = ACT2FN[hidden_act] self.activation = ACT2FN[hidden_act]
self.drop1 = nn.Dropout(dropout_rate) self.dropout1 = nn.Dropout(dropout_rate)
self.fc2 = nn.Conv2d(hidden_size, out_channels, kernel_size=1) self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
self.drop2 = nn.Dropout(dropout_rate) self.dropout2 = nn.Dropout(dropout_rate)
class VanLargeKernelAttention(nn.Sequential): class VanLargeKernelAttention(nn.Sequential):
...@@ -267,10 +269,10 @@ class VanLayer(nn.Module): ...@@ -267,10 +269,10 @@ class VanLayer(nn.Module):
): ):
super().__init__() super().__init__()
self.drop_path = VanDropPath(drop_path) if drop_path_rate > 0.0 else nn.Identity() self.drop_path = VanDropPath(drop_path) if drop_path_rate > 0.0 else nn.Identity()
self.pre_norm = nn.BatchNorm2d(hidden_size) self.pre_normomalization = nn.BatchNorm2d(hidden_size)
self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act) self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act)
self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value) self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
self.post_norm = nn.BatchNorm2d(hidden_size) self.post_normalization = nn.BatchNorm2d(hidden_size)
self.mlp = VanMlpLayer( self.mlp = VanMlpLayer(
hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate
) )
...@@ -279,7 +281,7 @@ class VanLayer(nn.Module): ...@@ -279,7 +281,7 @@ class VanLayer(nn.Module):
def forward(self, hidden_state): def forward(self, hidden_state):
residual = hidden_state residual = hidden_state
# attention # attention
hidden_state = self.pre_norm(hidden_state) hidden_state = self.pre_normomalization(hidden_state)
hidden_state = self.attention(hidden_state) hidden_state = self.attention(hidden_state)
hidden_state = self.attention_scaling(hidden_state) hidden_state = self.attention_scaling(hidden_state)
hidden_state = self.drop_path(hidden_state) hidden_state = self.drop_path(hidden_state)
...@@ -287,7 +289,7 @@ class VanLayer(nn.Module): ...@@ -287,7 +289,7 @@ class VanLayer(nn.Module):
hidden_state = residual + hidden_state hidden_state = residual + hidden_state
residual = hidden_state residual = hidden_state
# mlp # mlp
hidden_state = self.post_norm(hidden_state) hidden_state = self.post_normalization(hidden_state)
hidden_state = self.mlp(hidden_state) hidden_state = self.mlp(hidden_state)
hidden_state = self.mlp_scaling(hidden_state) hidden_state = self.mlp_scaling(hidden_state)
hidden_state = self.drop_path(hidden_state) hidden_state = self.drop_path(hidden_state)
...@@ -325,7 +327,7 @@ class VanStage(nn.Module): ...@@ -325,7 +327,7 @@ class VanStage(nn.Module):
for _ in range(depth) for _ in range(depth)
] ]
) )
self.norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_state): def forward(self, hidden_state):
hidden_state = self.embeddings(hidden_state) hidden_state = self.embeddings(hidden_state)
...@@ -333,7 +335,7 @@ class VanStage(nn.Module): ...@@ -333,7 +335,7 @@ class VanStage(nn.Module):
# rearrange b c h w -> b (h w) c # rearrange b c h w -> b (h w) c
batch_size, hidden_size, height, width = hidden_state.shape batch_size, hidden_size, height, width = hidden_state.shape
hidden_state = hidden_state.flatten(2).transpose(1, 2) hidden_state = hidden_state.flatten(2).transpose(1, 2)
hidden_state = self.norm(hidden_state) hidden_state = self.normalization(hidden_state)
# rearrange b (h w) c- > b c h w # rearrange b (h w) c- > b c h w
hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2) hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
return hidden_state return hidden_state
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment