Unverified commit 23f413c2, authored by Muhammed Abdullah, committed by GitHub

Added Dropout parameter to Models Constructors (#4580)



* Added Dropout parameter to Models

* Added argument description for dropout in MobileNet v2 and v3
  Updated quantization/googlenet.py as per the constructor changes in googlenet

* Moved new dropout parameter to the end

* Updated googlenet.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 321f39e7
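In summary, each hard-coded dropout probability below becomes a constructor argument whose default preserves the previous behavior. A minimal usage sketch (assumes a torchvision build that includes this commit; the factory functions forward `**kwargs` to the constructors):

```python
import torchvision.models as models

# Keyword arguments are forwarded from the factory functions to the model
# constructors, so the new `dropout` parameter is reachable from both paths.
net_default = models.alexnet()           # dropout=0.5, same as before the change
net_light = models.alexnet(dropout=0.1)  # weaker regularization, e.g. for fine-tuning
```

The per-file diffs follow.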
@@ -15,7 +15,7 @@ model_urls = {
 class AlexNet(nn.Module):
-    def __init__(self, num_classes: int = 1000) -> None:
+    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
         super(AlexNet, self).__init__()
         self.features = nn.Sequential(
             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
@@ -34,10 +34,10 @@ class AlexNet(nn.Module):
         )
         self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
         self.classifier = nn.Sequential(
-            nn.Dropout(),
+            nn.Dropout(p=dropout),
             nn.Linear(256 * 6 * 6, 4096),
             nn.ReLU(inplace=True),
-            nn.Dropout(),
+            nn.Dropout(p=dropout),
             nn.Linear(4096, 4096),
             nn.ReLU(inplace=True),
             nn.Linear(4096, num_classes),
...
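Worth noting for AlexNet: the single `dropout` argument feeds both `nn.Dropout` layers in the classifier. A quick hedged check, using the class directly:

```python
import torch
from torchvision.models import AlexNet

model = AlexNet(dropout=0.4)
# Both classifier Dropout layers are built from the same argument.
probs = [m.p for m in model.classifier if isinstance(m, torch.nn.Dropout)]
print(probs)  # [0.4, 0.4]
```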
@@ -71,6 +71,8 @@ class GoogLeNet(nn.Module):
         transform_input: bool = False,
         init_weights: Optional[bool] = None,
         blocks: Optional[List[Callable[..., nn.Module]]] = None,
+        dropout: float = 0.2,
+        dropout_aux: float = 0.7,
     ) -> None:
         super(GoogLeNet, self).__init__()
         if blocks is None:
@@ -112,14 +114,14 @@ class GoogLeNet(nn.Module):
         self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

         if aux_logits:
-            self.aux1 = inception_aux_block(512, num_classes)
-            self.aux2 = inception_aux_block(528, num_classes)
+            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
+            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
         else:
             self.aux1 = None  # type: ignore[assignment]
             self.aux2 = None  # type: ignore[assignment]

         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.dropout = nn.Dropout(0.2)
+        self.dropout = nn.Dropout(p=dropout)
         self.fc = nn.Linear(1024, num_classes)

         if init_weights:
@@ -264,7 +266,11 @@ class Inception(nn.Module):

 class InceptionAux(nn.Module):
     def __init__(
-        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
+        self,
+        in_channels: int,
+        num_classes: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.7,
     ) -> None:
         super(InceptionAux, self).__init__()
         if conv_block is None:
@@ -273,6 +279,7 @@ class InceptionAux(nn.Module):
         self.fc1 = nn.Linear(2048, 1024)
         self.fc2 = nn.Linear(1024, num_classes)
+        self.dropout = nn.Dropout(p=dropout)

     def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
@@ -284,7 +291,7 @@ class InceptionAux(nn.Module):
         # N x 2048
         x = F.relu(self.fc1(x), inplace=True)
         # N x 1024
-        x = F.dropout(x, 0.7, training=self.training)
+        x = self.dropout(x)
         # N x 1024
         x = self.fc2(x)
         # N x 1000 (num_classes)
...
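GoogLeNet thus exposes two knobs: `dropout` for the main classifier and `dropout_aux` for the two auxiliary heads. The aux heads also switch from functional `F.dropout` to a registered `nn.Dropout` module, so the probability shows up in `repr(model)` and travels with the module state. A hedged sketch, constructing the class directly with aux logits enabled:

```python
from torchvision.models import GoogLeNet

model = GoogLeNet(aux_logits=True, init_weights=True, dropout=0.3, dropout_aux=0.5)
print(model.dropout)       # Dropout(p=0.3, inplace=False)
print(model.aux1.dropout)  # Dropout(p=0.5, inplace=False)
```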
@@ -70,6 +70,7 @@ class Inception3(nn.Module):
         transform_input: bool = False,
         inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
         init_weights: Optional[bool] = None,
+        dropout: float = 0.5,
     ) -> None:
         super(Inception3, self).__init__()
         if inception_blocks is None:
@@ -115,7 +116,7 @@ class Inception3(nn.Module):
         self.Mixed_7b = inception_e(1280)
         self.Mixed_7c = inception_e(2048)
         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.dropout = nn.Dropout()
+        self.dropout = nn.Dropout(p=dropout)
         self.fc = nn.Linear(2048, num_classes)
         if init_weights:
             for m in self.modules():
...
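For Inception3, the default of `dropout=0.5` matches the bare `nn.Dropout()` it replaces, whose implicit probability is 0.5, so behavior is unchanged unless a caller overrides it. Sketch (direct construction; `init_weights` set explicitly to avoid the deprecation path):

```python
from torchvision.models import Inception3

model = Inception3(init_weights=True, dropout=0.5)
print(model.dropout.p)  # 0.5, the same probability nn.Dropout() used implicitly
```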
@@ -93,6 +93,7 @@ class MobileNetV2(nn.Module):
         round_nearest: int = 8,
         block: Optional[Callable[..., nn.Module]] = None,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.2,
     ) -> None:
         """
         MobileNet V2 main class
@@ -105,6 +106,7 @@ class MobileNetV2(nn.Module):
             Set to 1 to turn off rounding
             block: Module specifying inverted residual building block for mobilenet
             norm_layer: Module specifying the normalization layer to use
+            dropout (float): The dropout probability
         """
         super(MobileNetV2, self).__init__()
@@ -161,7 +163,7 @@ class MobileNetV2(nn.Module):
         # building classifier
         self.classifier = nn.Sequential(
-            nn.Dropout(0.2),
+            nn.Dropout(p=dropout),
             nn.Linear(self.last_channel, num_classes),
         )
...
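Same pattern for MobileNetV2: the classifier's `nn.Dropout(0.2)` becomes `nn.Dropout(p=dropout)` with a matching default. A hedged sketch through the factory function (assumes `**kwargs` forwarding, as in the torchvision of this commit):

```python
import torchvision.models as models

model = models.mobilenet_v2(dropout=0.05)  # lighter regularization for fine-tuning
print(model.classifier[0])                 # Dropout(p=0.05, inplace=False)
```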
@@ -135,6 +135,7 @@ class MobileNetV3(nn.Module):
         num_classes: int = 1000,
         block: Optional[Callable[..., nn.Module]] = None,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.2,
         **kwargs: Any,
     ) -> None:
         """
@@ -146,6 +147,7 @@ class MobileNetV3(nn.Module):
             num_classes (int): Number of classes
             block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
             norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+            dropout (float): The dropout probability
         """
         super().__init__()
@@ -200,7 +202,7 @@ class MobileNetV3(nn.Module):
         self.classifier = nn.Sequential(
             nn.Linear(lastconv_output_channels, last_channel),
             nn.Hardswish(inplace=True),
-            nn.Dropout(p=0.2, inplace=True),
+            nn.Dropout(p=dropout, inplace=True),
             nn.Linear(last_channel, num_classes),
         )
...
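MobileNetV3 keeps `inplace=True` on its dropout, so only the probability is parameterized. Sketch (assumes the `mobilenet_v3_large` factory forwards kwargs):

```python
import torchvision.models as models

model = models.mobilenet_v3_large(dropout=0.1)
print(model.classifier[2])  # Dropout(p=0.1, inplace=True)
```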
@@ -116,7 +116,6 @@ class QuantizableInceptionAux(InceptionAux):
             conv_block=QuantizableBasicConv2d, *args, **kwargs
         )
         self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.7)

     def forward(self, x: Tensor) -> Tensor:
         # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
...
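Because `InceptionAux` now owns a `self.dropout` module, the quantizable subclass no longer needs its duplicate `nn.Dropout(0.7)`: it inherits the module, and the `dropout_aux`-driven probability, from the base class. A hedged check (assumes the torchvision quantization module layout of this commit):

```python
from torchvision.models.quantization.googlenet import QuantizableInceptionAux

aux = QuantizableInceptionAux(512, 1000)  # in_channels, num_classes
print(aux.dropout)  # Dropout(p=0.7, inplace=False), now defined on the base class
```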
@@ -33,7 +33,7 @@ class Fire(nn.Module):

 class SqueezeNet(nn.Module):
-    def __init__(self, version: str = "1_0", num_classes: int = 1000) -> None:
+    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
         super(SqueezeNet, self).__init__()
         self.num_classes = num_classes
         if version == "1_0":
@@ -77,7 +77,7 @@ class SqueezeNet(nn.Module):
         # Final convolution is initialized differently from the rest
         final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
         self.classifier = nn.Sequential(
-            nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
+            nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
         )

         for m in self.modules():
...
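SqueezeNet's dropout sits in front of a final 1x1 convolution rather than a linear layer, but the parameterization is identical. Sketch, constructing the class directly with the version string from the diff:

```python
from torchvision.models import SqueezeNet

model = SqueezeNet(version="1_0", num_classes=1000, dropout=0.5)
print(model.classifier[0])  # Dropout(p=0.5, inplace=False), ahead of the final conv
```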
@@ -32,17 +32,19 @@ model_urls = {
 class VGG(nn.Module):
-    def __init__(self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True) -> None:
+    def __init__(
+        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
+    ) -> None:
         super(VGG, self).__init__()
         self.features = features
         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
         self.classifier = nn.Sequential(
             nn.Linear(512 * 7 * 7, 4096),
             nn.ReLU(True),
-            nn.Dropout(),
+            nn.Dropout(p=dropout),
             nn.Linear(4096, 4096),
             nn.ReLU(True),
-            nn.Dropout(),
+            nn.Dropout(p=dropout),
             nn.Linear(4096, num_classes),
         )
         if init_weights:
...
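Finally, VGG: the signature is reflowed onto multiple lines to fit the new argument, and both classifier dropouts take the shared probability. Usage through a factory function (assumes kwargs forwarding, e.g. `vgg16`):

```python
import torchvision.models as models

model = models.vgg16(dropout=0.4)
# classifier layout: Linear, ReLU, Dropout, Linear, ReLU, Dropout, Linear
print(model.classifier[2].p, model.classifier[5].p)  # 0.4 0.4
```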