Unverified commit 4ae20e53, authored by Vasilis Vryniotis, committed by GitHub

Move all weight initializations from private methods to constructors (#5331)

* Move weight initialization into constructors.

* Fixing mypy for ViT.

* Remove unnecessary import.
parent cd00ea15
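
All the hunks below apply the same pattern. As a rough sketch (the TinyNet module here is hypothetical, not taken from torchvision), the refactor replaces a call to a private helper with the same initialization loop written inline in the constructor:

import torch.nn as nn

# Hypothetical toy module illustrating the refactor in this commit: the loop that
# used to live in a private `_initialize_weights()` helper now runs inline at the
# end of `__init__`, so the module is fully initialized as soon as it is built.
class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, num_classes)

        # was: self._initialize_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)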
@@ -90,15 +90,12 @@ class GoogLeNet(nn.Module):
         self.fc = nn.Linear(1024, num_classes)

         if init_weights:
-            self._initialize_weights()
-
-    def _initialize_weights(self) -> None:
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-                torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)

     def _transform_input(self, x: Tensor) -> Tensor:
         if self.transform_input:
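
A quick standalone illustration of the truncated-normal call used in the GoogLeNet hunk above (the Linear layer here is just an illustrative stand-in): `a` and `b` are absolute cutoffs rather than multiples of the standard deviation, so samples come from N(0, 0.01) clipped to the interval [-2, 2].

import torch.nn as nn

# Not part of the diff: trunc_normal_ samples from N(mean, std**2) and rejects
# values outside [a, b], where a and b are absolute bounds.
layer = nn.Linear(64, 64)
nn.init.trunc_normal_(layer.weight, mean=0.0, std=0.01, a=-2, b=2)

print(layer.weight.std())        # close to 0.01
print(layer.weight.abs().max())  # well inside the [-2, 2] bounds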
@@ -128,15 +128,7 @@ class MNASNet(torch.nn.Module):
         ]
         self.layers = nn.Sequential(*layers)
         self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
-        self._initialize_weights()
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.layers(x)
-        # Equivalent to global avgpool and removing H and W dimensions.
-        x = x.mean([2, 3])
-        return self.classifier(x)
-
-    def _initialize_weights(self) -> None:
+
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
@@ -149,6 +141,12 @@ class MNASNet(torch.nn.Module):
                 nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                 nn.init.zeros_(m.bias)

+    def forward(self, x: Tensor) -> Tensor:
+        x = self.layers(x)
+        # Equivalent to global avgpool and removing H and W dimensions.
+        x = x.mean([2, 3])
+        return self.classifier(x)
+
     def _load_from_state_dict(
         self,
         state_dict: Dict,
@@ -132,9 +132,6 @@ class FeatureEncoder(nn.Module):
         self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)

-        self._init_weights()
-
-    def _init_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
@@ -366,20 +366,6 @@ class RegNet(nn.Module):
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

-        # Init weights and good to go
-        self._reset_parameters()
-
-    def forward(self, x: Tensor) -> Tensor:
-        x = self.stem(x)
-        x = self.trunk_output(x)
-
-        x = self.avgpool(x)
-        x = x.flatten(start_dim=1)
-        x = self.fc(x)
-
-        return x
-
-    def _reset_parameters(self) -> None:
         # Performs ResNet-style weight initialization
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -393,6 +379,16 @@ class RegNet(nn.Module):
                 nn.init.normal_(m.weight, mean=0.0, std=0.01)
                 nn.init.zeros_(m.bias)

+    def forward(self, x: Tensor) -> Tensor:
+        x = self.stem(x)
+        x = self.trunk_output(x)
+
+        x = self.avgpool(x)
+        x = x.flatten(start_dim=1)
+        x = self.fc(x)
+
+        return x
+

 def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
     norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
@@ -50,7 +50,17 @@ class VGG(nn.Module):
             nn.Linear(4096, num_classes),
         )
         if init_weights:
-            self._initialize_weights()
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                    if m.bias is not None:
+                        nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.Linear):
+                    nn.init.normal_(m.weight, 0, 0.01)
+                    nn.init.constant_(m.bias, 0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.features(x)
@@ -59,19 +69,6 @@ class VGG(nn.Module):
         x = self.classifier(x)
         return x

-    def _initialize_weights(self) -> None:
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-                if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.constant_(m.bias, 0)
-

 def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
     layers: List[nn.Module] = []
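
One usage note, hedged since it depends on the model builders rather than on this diff: the `init_weights` flag above lets callers that immediately load pretrained weights skip the random initialization, roughly like this.

from torchvision.models import vgg16

# Sketch assuming the standard vgg16 builder forwards extra kwargs to the VGG
# constructor: init_weights=False skips the init loop shown above, which is what
# the pretrained code path does since load_state_dict() overwrites every
# parameter anyway.
model = vgg16(init_weights=False)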
@@ -223,7 +223,17 @@ class VideoResNet(nn.Module):
         self.fc = nn.Linear(512 * block.expansion, num_classes)

         # init weights
-        self._initialize_weights()
+        for m in self.modules():
+            if isinstance(m, nn.Conv3d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm3d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)

         if zero_init_residual:
             for m in self.modules():
@@ -270,19 +280,6 @@ class VideoResNet(nn.Module):
         return nn.Sequential(*layers)

-    def _initialize_weights(self) -> None:
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-                if m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.BatchNorm3d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.Linear):
-                nn.init.normal_(m.weight, 0, 0.01)
-                nn.init.constant_(m.bias, 0)
-

 def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
     model = VideoResNet(**kwargs)
@@ -44,9 +44,7 @@ class MLPBlock(nn.Sequential):
         self.dropout_1 = nn.Dropout(dropout)
         self.linear_2 = nn.Linear(mlp_dim, in_dim)
         self.dropout_2 = nn.Dropout(dropout)
-        self._init_weights()
-
-    def _init_weights(self):
+
         nn.init.xavier_uniform_(self.linear_1.weight)
         nn.init.xavier_uniform_(self.linear_2.weight)
         nn.init.normal_(self.linear_1.bias, std=1e-6)
@@ -211,28 +209,29 @@ class VisionTransformer(nn.Module):
             heads_layers["head"] = nn.Linear(representation_size, num_classes)

         self.heads = nn.Sequential(heads_layers)
-        self._init_weights()

-    def _init_weights(self):
         if isinstance(self.conv_proj, nn.Conv2d):
             # Init the patchify stem
             fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
             nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
-            nn.init.zeros_(self.conv_proj.bias)
-        else:
+            if self.conv_proj.bias is not None:
+                nn.init.zeros_(self.conv_proj.bias)
+        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
             # Init the last 1x1 conv of the conv stem
             nn.init.normal_(
                 self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
             )
-            nn.init.zeros_(self.conv_proj.conv_last.bias)
+            if self.conv_proj.conv_last.bias is not None:
+                nn.init.zeros_(self.conv_proj.conv_last.bias)

-        if hasattr(self.heads, "pre_logits"):
+        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
             fan_in = self.heads.pre_logits.in_features
             nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
             nn.init.zeros_(self.heads.pre_logits.bias)

-        nn.init.zeros_(self.heads.head.weight)
-        nn.init.zeros_(self.heads.head.bias)
+        if isinstance(self.heads.head, nn.Linear):
+            nn.init.zeros_(self.heads.head.weight)
+            nn.init.zeros_(self.heads.head.bias)

     def _process_input(self, x: torch.Tensor) -> torch.Tensor:
         n, c, h, w = x.shape
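
The extra isinstance checks in the VisionTransformer hunk are the "Fixing mypy for ViT" part of the commit message: conv_proj can be either a plain Conv2d patchify stem or a convolutional stem whose last layer is reached through conv_last, so the type is narrowed before layer-specific attributes are touched. A minimal sketch of the same narrowing pattern, using a hypothetical proj variable:

import math

import torch.nn as nn

# `proj` is typed as a generic nn.Module; isinstance() narrows it before
# Conv2d-specific attributes such as .in_channels or .weight are accessed,
# which is what mypy needs to type-check the initialization code.
proj: nn.Module = nn.Conv2d(3, 768, kernel_size=16, stride=16)

if isinstance(proj, nn.Conv2d):
    fan_in = proj.in_channels * proj.kernel_size[0] * proj.kernel_size[1]
    nn.init.trunc_normal_(proj.weight, std=math.sqrt(1 / fan_in))
    if proj.bias is not None:
        nn.init.zeros_(proj.bias)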