Unverified Commit ce778bb7 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

CleanUp DenseNet code (#5966)

parent b0dbbd7a
......@@ -34,22 +34,14 @@ class _DenseLayer(nn.Module):
self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
) -> None:
super().__init__()
self.norm1: nn.BatchNorm2d
self.add_module("norm1", nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU
self.add_module("relu1", nn.ReLU(inplace=True))
self.conv1: nn.Conv2d
self.add_module(
"conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
)
self.norm2: nn.BatchNorm2d
self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
self.relu2: nn.ReLU
self.add_module("relu2", nn.ReLU(inplace=True))
self.conv2: nn.Conv2d
self.add_module(
"conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
)
self.norm1 = nn.BatchNorm2d(num_input_features)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
......@@ -136,10 +128,10 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None:
super().__init__()
self.add_module("norm", nn.BatchNorm2d(num_input_features))
self.add_module("relu", nn.ReLU(inplace=True))
self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))
self.norm = nn.BatchNorm2d(num_input_features)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
class DenseNet(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment