Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
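The commits above mostly adjust how usort groups imports before black reflows them. As a rough, illustrative sketch (assuming the default usort categories this PR falls back to after the reverts; the module names are simply taken from files touched in this diff), a module header ends up laid out like this:

# Illustrative sketch only, not a file from this diff.
from collections import OrderedDict  # group 1: standard library
from typing import Any, Optional

import torch  # group 2: third-party (absolute torchvision imports end up here as well)
import torch.nn as nn
from torch import Tensor

# group 3: relative first-party imports come last, for example:
# from .._internally_replaced_utils import load_state_dict_from_url
# from .utils import _replace_relu, quantize_model

Each group is separated by a blank line and sorted within itself; black then normalizes quotes and line wrapping on top of the sorted result.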
from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models import shufflenetv2

from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model


__all__ = [
    "QuantizableShuffleNetV2",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]

quant_model_urls = {
    "shufflenetv2_x0.5_fbgemm": None,
    "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
    "shufflenetv2_x1.5_fbgemm": None,
    "shufflenetv2_x2.0_fbgemm": None,
}

@@ -42,9 +45,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableShuffleNetV2, self).__init__(  # type: ignore[misc]
            *args, inverted_residual=QuantizableInvertedResidual, **kwargs
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

@@ -69,9 +70,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
        for m in self.modules():
            if type(m) == QuantizableInvertedResidual:
                if len(m.branch1._modules.items()) > 0:
                    torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
                torch.quantization.fuse_modules(
                    m.branch2,
                    [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],

@@ -93,19 +92,18 @@ def _shufflenetv2(
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls[arch + "_" + backend]
        else:
            model_url = shufflenetv2.model_urls[arch]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)
    return model

@@ -127,8 +125,9 @@ def shufflenet_v2_x0_5(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _shufflenetv2(
        "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs
    )


def shufflenet_v2_x1_0(

@@ -147,8 +146,9 @@ def shufflenet_v2_x1_0(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _shufflenetv2(
        "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs
    )


def shufflenet_v2_x1_5(

@@ -167,8 +167,9 @@ def shufflenet_v2_x1_5(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _shufflenetv2(
        "shufflenetv2_x1.5", pretrained, progress, quantize, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs
    )


def shufflenet_v2_x2_0(

@@ -187,5 +188,6 @@ def shufflenet_v2_x2_0(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _shufflenetv2(
        "shufflenetv2_x2.0", pretrained, progress, quantize, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs
    )
@@ -23,14 +23,15 @@ def quantize_model(model: nn.Module, backend: str) -> None:
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer
        )

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]
@@ -4,11 +4,11 @@
import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import nn, Tensor

from .._internally_replaced_utils import load_state_dict_from_url

@@ -16,10 +16,23 @@ from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible


__all__ = [
    "RegNet",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]


model_urls = {

@@ -42,8 +55,9 @@ class SimpleStemIN(ConvNormActivation):
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__(
            width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
        )


class BottleneckTransform(nn.Sequential):

@@ -64,10 +78,12 @@ class BottleneckTransform(nn.Sequential):
        w_b = int(round(width_out * bottleneck_multiplier))
        g = w_b // group_width

        layers["a"] = ConvNormActivation(
            width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
        )
        layers["b"] = ConvNormActivation(
            w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
        )

        if se_ratio:
            # The SE reduction ratio is defined with respect to the

@@ -79,8 +95,9 @@ class BottleneckTransform(nn.Sequential):
                activation=activation_layer,
            )

        layers["c"] = ConvNormActivation(
            w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
        )
        super().__init__(layers)

@@ -104,8 +121,9 @@ class ResBottleneckBlock(nn.Module):
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
            self.proj = ConvNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
        self.f = BottleneckTransform(
            width_in,
            width_out,

@@ -217,10 +235,7 @@ class BlockParams:
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
        num_stages = len(set(block_widths))

        # Convert to per stage parameters

@@ -254,14 +269,12 @@ class BlockParams:
        )

    def _get_expanded_params(self):
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibilty(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.

@@ -389,8 +402,7 @@ def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs)

@@ -403,8 +415,7 @@ def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs)

@@ -417,8 +428,9 @@ def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(
        depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs)

@@ -431,8 +443,9 @@ def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(
        depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs)

@@ -445,8 +458,9 @@ def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any)
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(
        depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
    )
    return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs)

@@ -459,8 +473,9 @@ def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(
        depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
    )
    return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs)

@@ -473,8 +488,9 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(
        depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
    )
    return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs)

@@ -487,8 +503,7 @@ def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs)

@@ -501,8 +516,7 @@ def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs)

@@ -515,8 +529,7 @@ def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs)

@@ -529,8 +542,7 @@ def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: An
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs)

@@ -543,8 +555,7 @@ def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any)
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs)

@@ -557,8 +568,7 @@ def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs)

@@ -571,8 +581,8 @@ def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs)


# TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF
from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url


__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]


model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:

@@ -46,13 +65,13 @@ class BasicBlock(nn.Module):
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1

@@ -101,12 +120,12 @@ class Bottleneck(nn.Module):
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)

@@ -142,7 +161,6 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],

@@ -152,7 +170,7 @@ class ResNet(nn.Module):
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:

@@ -166,28 +184,26 @@ class ResNet(nn.Module):
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

@@ -202,8 +218,14 @@ class ResNet(nn.Module):
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

@@ -217,13 +239,23 @@ class ResNet(nn.Module):
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

@@ -255,12 +287,11 @@ def _resnet(
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model

@@ -273,8 +304,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -285,8 +315,7 @@ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -297,8 +326,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -309,8 +337,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -321,8 +348,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -333,10 +359,9 @@ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: A
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4
    return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -347,10 +372,9 @@ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -366,9 +390,8 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:

@@ -384,6 +407,5 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
@@ -6,14 +6,9 @@ from torch.nn import functional as F
class _SimpleSegmentationModel(nn.Module):
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

@@ -27,13 +22,13 @@ class _SimpleSegmentationModel(nn.Module):
        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from ._utils import _SimpleSegmentationModel

@@ -24,6 +25,7 @@ class DeepLabV3(_SimpleSegmentationModel):
        the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    pass

@@ -34,7 +36,7 @@ class DeepLabHead(nn.Sequential):
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),
        )

@@ -43,7 +45,7 @@ class ASPPConv(nn.Sequential):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super(ASPPConv, self).__init__(*modules)

@@ -54,23 +56,23 @@ class ASPPPooling(nn.Sequential):
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
        super(ASPP, self).__init__()
        modules = []
        modules.append(
            nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU())
        )

        rates = tuple(atrous_rates)
        for rate in rates:

@@ -84,7 +86,8 @@ class ASPP(nn.Module):
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _res = []
@@ -19,6 +19,7 @@ class FCN(_SimpleSegmentationModel):
        the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """

    pass

@@ -30,7 +31,7 @@ class FCNHead(nn.Sequential):
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]
        super(FCNHead, self).__init__(*layers)
from collections import OrderedDict
from typing import Dict

from torch import nn, Tensor
from torch.nn import functional as F


__all__ = ["LRASPP"]

@@ -25,12 +25,7 @@ class LRASPP(nn.Module):
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        self.backbone = backbone

@@ -39,7 +34,7 @@ class LRASPP(nn.Module):
    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        features = self.backbone(input)
        out = self.classifier(features)
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out

@@ -48,19 +43,12 @@ class LRASPPHead(nn.Module):
class LRASPPHead(nn.Module):
    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),

@@ -77,6 +65,6 @@ class LRASPPHead(nn.Module):
        x = self.cbr(high)
        s = self.scale(high)
        x = x * s
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)
from typing import Any, Optional

from torch import nn

from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet
from .._utils import IntermediateLayerGetter
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP


__all__ = [
    "fcn_resnet50",
    "fcn_resnet101",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "lraspp_mobilenet_v3_large",
]


model_urls = {
    "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
    "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
    "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}


def _segm_model(
    name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True
) -> nn.Module:
    if "resnet" in backbone_name:
        backbone = resnet.__dict__[backbone_name](
            pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
        )
        out_layer = "layer4"
        out_inplanes = 2048
        aux_layer = "layer3"
        aux_inplanes = 1024
    elif "mobilenet_v3" in backbone_name:
        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features

        # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.

@@ -52,11 +55,11 @@ def _segm_model(
        aux_layer = str(aux_pos)
        aux_inplanes = backbone[aux_pos].out_channels
    else:
        raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name))

    return_layers = {out_layer: "out"}
    if aux:
        return_layers[aux_layer] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None

@@ -64,8 +67,8 @@ def _segm_model(
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    model_map = {
        "deeplabv3": (DeepLabHead, DeepLabV3),
        "fcn": (FCNHead, FCN),
    }
    classifier = model_map[name][0](out_inplanes, num_classes)
    base_model = model_map[name][1]

@@ -81,7 +84,7 @@ def _load_model(
    progress: bool,
    num_classes: int,
    aux_loss: Optional[bool],
    **kwargs: Any,
) -> nn.Module:
    if pretrained:
        aux_loss = True

@@ -93,10 +96,10 @@ def _load_model(
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
    arch = arch_type + "_" + backbone + "_coco"
    model_url = model_urls.get(arch, None)
    if model_url is None:
        raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
    else:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)

@@ -113,7 +116,7 @@ def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_ba
    low_channels = backbone[low_pos].out_channels
    high_channels = backbone[high_pos].out_channels
    backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})

    model = LRASPP(backbone, low_channels, high_channels, num_classes)
    return model

@@ -124,7 +127,7 @@ def fcn_resnet50(
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
...@@ -135,7 +138,7 @@ def fcn_resnet50( ...@@ -135,7 +138,7 @@ def fcn_resnet50(
num_classes (int): number of output classes of the model (including the background) num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101( def fcn_resnet101(
...@@ -143,7 +146,7 @@ def fcn_resnet101( ...@@ -143,7 +146,7 @@ def fcn_resnet101(
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: int = 21,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
**kwargs: Any **kwargs: Any,
) -> nn.Module: ) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
...@@ -154,7 +157,7 @@ def fcn_resnet101( ...@@ -154,7 +157,7 @@ def fcn_resnet101(
num_classes (int): number of output classes of the model (including the background) num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet50( def deeplabv3_resnet50(
...@@ -162,7 +165,7 @@ def deeplabv3_resnet50( ...@@ -162,7 +165,7 @@ def deeplabv3_resnet50(
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: int = 21,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
**kwargs: Any **kwargs: Any,
) -> nn.Module: ) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone. """Constructs a DeepLabV3 model with a ResNet-50 backbone.
...@@ -173,7 +176,7 @@ def deeplabv3_resnet50( ...@@ -173,7 +176,7 @@ def deeplabv3_resnet50(
num_classes (int): number of output classes of the model (including the background) num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet101( def deeplabv3_resnet101(
...@@ -181,7 +184,7 @@ def deeplabv3_resnet101( ...@@ -181,7 +184,7 @@ def deeplabv3_resnet101(
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: int = 21,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
**kwargs: Any **kwargs: Any,
) -> nn.Module: ) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone. """Constructs a DeepLabV3 model with a ResNet-101 backbone.
...@@ -192,7 +195,7 @@ def deeplabv3_resnet101( ...@@ -192,7 +195,7 @@ def deeplabv3_resnet101(
num_classes (int): The number of classes num_classes (int): The number of classes
aux_loss (bool): If True, include an auxiliary classifier aux_loss (bool): If True, include an auxiliary classifier
""" """
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large( def deeplabv3_mobilenet_v3_large(
...@@ -200,7 +203,7 @@ def deeplabv3_mobilenet_v3_large( ...@@ -200,7 +203,7 @@ def deeplabv3_mobilenet_v3_large(
progress: bool = True, progress: bool = True,
num_classes: int = 21, num_classes: int = 21,
aux_loss: Optional[bool] = None, aux_loss: Optional[bool] = None,
**kwargs: Any **kwargs: Any,
) -> nn.Module: ) -> nn.Module:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
...@@ -211,14 +214,11 @@ def deeplabv3_mobilenet_v3_large( ...@@ -211,14 +214,11 @@ def deeplabv3_mobilenet_v3_large(
num_classes (int): number of output classes of the model (including the background) num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large( def lraspp_mobilenet_v3_large(
pretrained: bool = False, pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any
progress: bool = True,
num_classes: int = 21,
**kwargs: Any
) -> nn.Module: ) -> nn.Module:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
...@@ -229,14 +229,14 @@ def lraspp_mobilenet_v3_large( ...@@ -229,14 +229,14 @@ def lraspp_mobilenet_v3_large(
num_classes (int): number of output classes of the model (including the background) num_classes (int): number of output classes of the model (including the background)
""" """
if kwargs.pop("aux_loss", False): if kwargs.pop("aux_loss", False):
raise NotImplementedError('This model does not use auxiliary loss') raise NotImplementedError("This model does not use auxiliary loss")
backbone_name = 'mobilenet_v3_large' backbone_name = "mobilenet_v3_large"
if pretrained: if pretrained:
kwargs["pretrained_backbone"] = False kwargs["pretrained_backbone"] = False
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs) model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
if pretrained: if pretrained:
_load_weights(model, 'lraspp', backbone_name, progress) _load_weights(model, "lraspp", backbone_name, progress)
return model return model
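
Usage sketch for the segmentation builders above (not part of this diff; assumes a standard torchvision install, with COCO weights fetched only when pretrained=True):

    import torch
    from torchvision.models.segmentation import fcn_resnet50, lraspp_mobilenet_v3_large

    model = fcn_resnet50(pretrained=False, num_classes=21).eval()
    out = model(torch.rand(1, 3, 224, 224))["out"]              # logits, shape [1, 21, 224, 224]

    lraspp = lraspp_mobilenet_v3_large(pretrained=False).eval()
    mask = lraspp(torch.rand(1, 3, 224, 224))["out"].argmax(1)  # per-pixel class ids, shape [1, 224, 224]
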
from typing import Callable, Any, List
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Callable, Any, List
__all__ = [ __all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"]
'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]
model_urls = { model_urls = {
'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
'shufflenetv2_x1.5': None, "shufflenetv2_x1.5": None,
'shufflenetv2_x2.0': None, "shufflenetv2_x2.0": None,
} }
...@@ -23,8 +22,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: ...@@ -23,8 +22,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
# reshape # reshape
x = x.view(batchsize, groups, x = x.view(batchsize, groups, channels_per_group, height, width)
channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous() x = torch.transpose(x, 1, 2).contiguous()
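
The reshape/transpose being condensed here is the entire channel-shuffle trick: split the channel axis into (groups, channels_per_group), swap the two new axes, and flatten back so channels from different groups interleave. A standalone illustration (not part of the diff):

    import torch

    def shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
        n, c, h, w = x.size()
        x = x.view(n, groups, c // groups, h, w)   # split channels into groups
        x = torch.transpose(x, 1, 2).contiguous()  # swap the group and per-group axes
        return x.view(n, -1, h, w)                 # flatten back to (n, c, h, w)

    x = torch.arange(8.0).view(1, 8, 1, 1)         # channel ids 0..7
    print(shuffle(x, groups=2).flatten())          # tensor([0., 4., 1., 5., 2., 6., 3., 7.])
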
...@@ -35,16 +33,11 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: ...@@ -35,16 +33,11 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
class InvertedResidual(nn.Module): class InvertedResidual(nn.Module):
def __init__( def __init__(self, inp: int, oup: int, stride: int) -> None:
self,
inp: int,
oup: int,
stride: int
) -> None:
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3): if not (1 <= stride <= 3):
raise ValueError('illegal stride value') raise ValueError("illegal stride value")
self.stride = stride self.stride = stride
branch_features = oup // 2 branch_features = oup // 2
...@@ -62,8 +55,14 @@ class InvertedResidual(nn.Module): ...@@ -62,8 +55,14 @@ class InvertedResidual(nn.Module):
self.branch1 = nn.Sequential() self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential( self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features, nn.Conv2d(
branch_features, kernel_size=1, stride=1, padding=0, bias=False), inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
...@@ -75,12 +74,7 @@ class InvertedResidual(nn.Module): ...@@ -75,12 +74,7 @@ class InvertedResidual(nn.Module):
@staticmethod @staticmethod
def depthwise_conv( def depthwise_conv(
i: int, i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
o: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False
) -> nn.Conv2d: ) -> nn.Conv2d:
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
...@@ -102,14 +96,14 @@ class ShuffleNetV2(nn.Module): ...@@ -102,14 +96,14 @@ class ShuffleNetV2(nn.Module):
stages_repeats: List[int], stages_repeats: List[int],
stages_out_channels: List[int], stages_out_channels: List[int],
num_classes: int = 1000, num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual inverted_residual: Callable[..., nn.Module] = InvertedResidual,
) -> None: ) -> None:
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3: if len(stages_repeats) != 3:
raise ValueError('expected stages_repeats as list of 3 positive ints') raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5: if len(stages_out_channels) != 5:
raise ValueError('expected stages_out_channels as list of 5 positive ints') raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels self._stage_out_channels = stages_out_channels
input_channels = 3 input_channels = 3
...@@ -127,9 +121,8 @@ class ShuffleNetV2(nn.Module): ...@@ -127,9 +121,8 @@ class ShuffleNetV2(nn.Module):
self.stage2: nn.Sequential self.stage2: nn.Sequential
self.stage3: nn.Sequential self.stage3: nn.Sequential
self.stage4: nn.Sequential self.stage4: nn.Sequential
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip( for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [inverted_residual(input_channels, output_channels, 2)] seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1): for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1)) seq.append(inverted_residual(output_channels, output_channels, 1))
...@@ -167,7 +160,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa ...@@ -167,7 +160,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
if pretrained: if pretrained:
model_url = model_urls[arch] model_url = model_urls[arch]
if model_url is None: if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
else: else:
state_dict = load_state_dict_from_url(model_url, progress=progress) state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
...@@ -185,8 +178,7 @@ def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -185,8 +178,7 @@ def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
...@@ -199,8 +191,7 @@ def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -199,8 +191,7 @@ def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
...@@ -213,8 +204,7 @@ def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -213,8 +204,7 @@ def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
...@@ -227,5 +217,4 @@ def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs ...@@ -227,5 +217,4 @@ def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
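
The four builders above differ only in the stage widths handed to _shufflenetv2. A quick smoke test (not part of the diff; pretrained weights exist only for the variants whose URL is not None):

    import torch
    from torchvision.models import shufflenet_v2_x1_0

    model = shufflenet_v2_x1_0(pretrained=False).eval()
    logits = model(torch.rand(1, 3, 224, 224))     # shape [1, 1000]
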
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.init as init import torch.nn.init as init
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Any
__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] __all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"]
model_urls = { model_urls = {
'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth', "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth', "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
} }
class Fire(nn.Module): class Fire(nn.Module):
def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
def __init__(
self,
inplanes: int,
squeeze_planes: int,
expand1x1_planes: int,
expand3x3_planes: int
) -> None:
super(Fire, self).__init__() super(Fire, self).__init__()
self.inplanes = inplanes self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
self.squeeze_activation = nn.ReLU(inplace=True) self.squeeze_activation = nn.ReLU(inplace=True)
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
kernel_size=1)
self.expand1x1_activation = nn.ReLU(inplace=True) self.expand1x1_activation = nn.ReLU(inplace=True)
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
kernel_size=3, padding=1)
self.expand3x3_activation = nn.ReLU(inplace=True) self.expand3x3_activation = nn.ReLU(inplace=True)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.squeeze_activation(self.squeeze(x)) x = self.squeeze_activation(self.squeeze(x))
return torch.cat([ return torch.cat(
self.expand1x1_activation(self.expand1x1(x)), [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
self.expand3x3_activation(self.expand3x3(x)) )
], 1)
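
A Fire block squeezes the input with a 1x1 conv, then expands it along parallel 1x1 and 3x3 branches and concatenates them, so the output width is expand1x1_planes + expand3x3_planes. A quick shape check (illustrative; Fire is defined in this module but not listed in __all__):

    import torch
    from torchvision.models.squeezenet import Fire

    fire = Fire(inplanes=96, squeeze_planes=16, expand1x1_planes=64, expand3x3_planes=64)
    y = fire(torch.rand(1, 96, 54, 54))
    print(y.shape)                                 # torch.Size([1, 128, 54, 54])
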
class SqueezeNet(nn.Module): class SqueezeNet(nn.Module):
def __init__(self, version: str = "1_0", num_classes: int = 1000) -> None:
def __init__(
self,
version: str = '1_0',
num_classes: int = 1000
) -> None:
super(SqueezeNet, self).__init__() super(SqueezeNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
if version == '1_0': if version == "1_0":
self.features = nn.Sequential( self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.Conv2d(3, 96, kernel_size=7, stride=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
...@@ -65,7 +52,7 @@ class SqueezeNet(nn.Module): ...@@ -65,7 +52,7 @@ class SqueezeNet(nn.Module):
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
Fire(512, 64, 256, 256), Fire(512, 64, 256, 256),
) )
elif version == '1_1': elif version == "1_1":
self.features = nn.Sequential( self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.Conv2d(3, 64, kernel_size=3, stride=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
...@@ -85,16 +72,12 @@ class SqueezeNet(nn.Module): ...@@ -85,16 +72,12 @@ class SqueezeNet(nn.Module):
# FIXME: Is this needed? SqueezeNet should only be called from the # FIXME: Is this needed? SqueezeNet should only be called from the
# FIXME: squeezenet1_x() functions # FIXME: squeezenet1_x() functions
# FIXME: This checking is not done for the other models # FIXME: This checking is not done for the other models
raise ValueError("Unsupported SqueezeNet version {version}:" raise ValueError("Unsupported SqueezeNet version {version}:" "1_0 or 1_1 expected".format(version=version))
"1_0 or 1_1 expected".format(version=version))
# Final convolution is initialized differently from the rest # Final convolution is initialized differently from the rest
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
nn.Dropout(p=0.5), nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
final_conv,
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
) )
for m in self.modules(): for m in self.modules():
...@@ -115,9 +98,8 @@ class SqueezeNet(nn.Module): ...@@ -115,9 +98,8 @@ class SqueezeNet(nn.Module):
def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet:
model = SqueezeNet(version, **kwargs) model = SqueezeNet(version, **kwargs)
if pretrained: if pretrained:
arch = 'squeezenet' + version arch = "squeezenet" + version
state_dict = load_state_dict_from_url(model_urls[arch], state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
...@@ -132,7 +114,7 @@ def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any ...@@ -132,7 +114,7 @@ def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _squeezenet('1_0', pretrained, progress, **kwargs) return _squeezenet("1_0", pretrained, progress, **kwargs)
def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet:
...@@ -146,4 +128,4 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any ...@@ -146,4 +128,4 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _squeezenet('1_1', pretrained, progress, **kwargs) return _squeezenet("1_1", pretrained, progress, **kwargs)
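
Minimal usage of the two entry points (not part of the diff); num_classes is forwarded to the SqueezeNet constructor, which sizes the final 1x1 conv accordingly:

    import torch
    from torchvision.models import squeezenet1_1

    model = squeezenet1_1(pretrained=False, num_classes=10).eval()
    logits = model(torch.rand(1, 3, 224, 224))     # shape [1, 10]
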
from typing import Union, List, Dict, Any, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast
__all__ = [ __all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', "VGG",
'vgg19_bn', 'vgg19', "vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19_bn",
"vgg19",
] ]
model_urls = { model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
} }
class VGG(nn.Module): class VGG(nn.Module):
def __init__(self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True) -> None:
def __init__(
self,
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None:
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
...@@ -55,7 +58,7 @@ class VGG(nn.Module): ...@@ -55,7 +58,7 @@ class VGG(nn.Module):
def _initialize_weights(self) -> None: def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
...@@ -70,7 +73,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ ...@@ -70,7 +73,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
layers: List[nn.Module] = [] layers: List[nn.Module] = []
in_channels = 3 in_channels = 3
for v in cfg: for v in cfg:
if v == 'M': if v == "M":
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else: else:
v = cast(int, v) v = cast(int, v)
...@@ -84,20 +87,19 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ ...@@ -84,20 +87,19 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
cfgs: Dict[str, List[Union[str, int]]] = { cfgs: Dict[str, List[Union[str, int]]] = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
} }
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
if pretrained: if pretrained:
kwargs['init_weights'] = False kwargs["init_weights"] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
...@@ -111,7 +113,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG ...@@ -111,7 +113,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) return _vgg("vgg11", "A", False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -123,7 +125,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ...@@ -123,7 +125,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs)
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -135,7 +137,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG ...@@ -135,7 +137,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) return _vgg("vgg13", "B", False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -147,7 +149,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ...@@ -147,7 +149,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs)
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -159,7 +161,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG ...@@ -159,7 +161,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) return _vgg("vgg16", "D", False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -171,7 +173,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ...@@ -171,7 +173,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs)
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -183,7 +185,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG ...@@ -183,7 +185,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
...@@ -195,4 +197,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ...@@ -195,4 +197,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs)
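
The cfgs table drives make_layers: each integer is the output width of a 3x3 conv and "M" inserts a 2x2 max-pool, with batch_norm adding a BatchNorm2d after every conv. A usage sketch (not part of the diff):

    import torch
    from torchvision.models import vgg16_bn
    from torchvision.models.vgg import cfgs, make_layers

    features = make_layers(cfgs["D"], batch_norm=True)   # the feature extractor behind vgg16_bn
    model = vgg16_bn(pretrained=False).eval()
    logits = model(torch.rand(1, 3, 224, 224))           # shape [1, 1000]
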
from torch import Tensor
import torch.nn as nn
from typing import Tuple, Optional, Callable, List, Type, Any, Union from typing import Tuple, Optional, Callable, List, Type, Any, Union
import torch.nn as nn
from torch import Tensor
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] __all__ = ["r3d_18", "mc3_18", "r2plus1d_18"]
model_urls = { model_urls = {
'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', "r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
} }
class Conv3DSimple(nn.Conv3d): class Conv3DSimple(nn.Conv3d):
def __init__( def __init__(
self, self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
in_planes: int,
out_planes: int,
midplanes: Optional[int] = None,
stride: int = 1,
padding: int = 1
) -> None: ) -> None:
super(Conv3DSimple, self).__init__( super(Conv3DSimple, self).__init__(
...@@ -30,7 +26,8 @@ class Conv3DSimple(nn.Conv3d): ...@@ -30,7 +26,8 @@ class Conv3DSimple(nn.Conv3d):
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=stride, stride=stride,
padding=padding, padding=padding,
bias=False) bias=False,
)
@staticmethod @staticmethod
def get_downsample_stride(stride: int) -> Tuple[int, int, int]: def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
...@@ -38,24 +35,22 @@ class Conv3DSimple(nn.Conv3d): ...@@ -38,24 +35,22 @@ class Conv3DSimple(nn.Conv3d):
class Conv2Plus1D(nn.Sequential): class Conv2Plus1D(nn.Sequential):
def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
def __init__(
self,
in_planes: int,
out_planes: int,
midplanes: int,
stride: int = 1,
padding: int = 1
) -> None:
super(Conv2Plus1D, self).__init__( super(Conv2Plus1D, self).__init__(
nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), nn.Conv3d(
stride=(1, stride, stride), padding=(0, padding, padding), in_planes,
bias=False), midplanes,
kernel_size=(1, 3, 3),
stride=(1, stride, stride),
padding=(0, padding, padding),
bias=False,
),
nn.BatchNorm3d(midplanes), nn.BatchNorm3d(midplanes),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), nn.Conv3d(
stride=(stride, 1, 1), padding=(padding, 0, 0), midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
bias=False)) ),
)
@staticmethod @staticmethod
def get_downsample_stride(stride: int) -> Tuple[int, int, int]: def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
...@@ -63,14 +58,8 @@ class Conv2Plus1D(nn.Sequential): ...@@ -63,14 +58,8 @@ class Conv2Plus1D(nn.Sequential):
class Conv3DNoTemporal(nn.Conv3d): class Conv3DNoTemporal(nn.Conv3d):
def __init__( def __init__(
self, self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
in_planes: int,
out_planes: int,
midplanes: Optional[int] = None,
stride: int = 1,
padding: int = 1
) -> None: ) -> None:
super(Conv3DNoTemporal, self).__init__( super(Conv3DNoTemporal, self).__init__(
...@@ -79,7 +68,8 @@ class Conv3DNoTemporal(nn.Conv3d): ...@@ -79,7 +68,8 @@ class Conv3DNoTemporal(nn.Conv3d):
kernel_size=(1, 3, 3), kernel_size=(1, 3, 3),
stride=(1, stride, stride), stride=(1, stride, stride),
padding=(0, padding, padding), padding=(0, padding, padding),
bias=False) bias=False,
)
@staticmethod @staticmethod
def get_downsample_stride(stride: int) -> Tuple[int, int, int]: def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
...@@ -102,14 +92,9 @@ class BasicBlock(nn.Module): ...@@ -102,14 +92,9 @@ class BasicBlock(nn.Module):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
conv_builder(inplanes, planes, midplanes, stride), conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
nn.BatchNorm3d(planes),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
conv_builder(planes, planes, midplanes),
nn.BatchNorm3d(planes)
) )
self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
...@@ -145,21 +130,17 @@ class Bottleneck(nn.Module): ...@@ -145,21 +130,17 @@ class Bottleneck(nn.Module):
# 1x1x1 # 1x1x1
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
nn.BatchNorm3d(planes),
nn.ReLU(inplace=True)
) )
# Second kernel # Second kernel
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
conv_builder(planes, planes, midplanes, stride), conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
nn.BatchNorm3d(planes),
nn.ReLU(inplace=True)
) )
# 1x1x1 # 1x1x1
self.conv3 = nn.Sequential( self.conv3 = nn.Sequential(
nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
nn.BatchNorm3d(planes * self.expansion) nn.BatchNorm3d(planes * self.expansion),
) )
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = downsample self.downsample = downsample
...@@ -182,35 +163,31 @@ class Bottleneck(nn.Module): ...@@ -182,35 +163,31 @@ class Bottleneck(nn.Module):
class BasicStem(nn.Sequential): class BasicStem(nn.Sequential):
"""The default conv-batchnorm-relu stem """The default conv-batchnorm-relu stem"""
"""
def __init__(self) -> None: def __init__(self) -> None:
super(BasicStem, self).__init__( super(BasicStem, self).__init__(
nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
padding=(1, 3, 3), bias=False),
nn.BatchNorm3d(64), nn.BatchNorm3d(64),
nn.ReLU(inplace=True)) nn.ReLU(inplace=True),
)
class R2Plus1dStem(nn.Sequential): class R2Plus1dStem(nn.Sequential):
"""R(2+1)D stem is different than the default one as it uses separated 3D convolution """R(2+1)D stem is different than the default one as it uses separated 3D convolution"""
"""
def __init__(self) -> None: def __init__(self) -> None:
super(R2Plus1dStem, self).__init__( super(R2Plus1dStem, self).__init__(
nn.Conv3d(3, 45, kernel_size=(1, 7, 7), nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
stride=(1, 2, 2), padding=(0, 3, 3),
bias=False),
nn.BatchNorm3d(45), nn.BatchNorm3d(45),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv3d(45, 64, kernel_size=(3, 1, 1), nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
stride=(1, 1, 1), padding=(1, 0, 0),
bias=False),
nn.BatchNorm3d(64), nn.BatchNorm3d(64),
nn.ReLU(inplace=True)) nn.ReLU(inplace=True),
)
class VideoResNet(nn.Module): class VideoResNet(nn.Module):
def __init__( def __init__(
self, self,
block: Type[Union[BasicBlock, Bottleneck]], block: Type[Union[BasicBlock, Bottleneck]],
...@@ -273,16 +250,15 @@ class VideoResNet(nn.Module): ...@@ -273,16 +250,15 @@ class VideoResNet(nn.Module):
conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
planes: int, planes: int,
blocks: int, blocks: int,
stride: int = 1 stride: int = 1,
) -> nn.Sequential: ) -> nn.Sequential:
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
ds_stride = conv_builder.get_downsample_stride(stride) ds_stride = conv_builder.get_downsample_stride(stride)
downsample = nn.Sequential( downsample = nn.Sequential(
nn.Conv3d(self.inplanes, planes * block.expansion, nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
kernel_size=1, stride=ds_stride, bias=False), nn.BatchNorm3d(planes * block.expansion),
nn.BatchNorm3d(planes * block.expansion)
) )
layers = [] layers = []
layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))
...@@ -296,8 +272,7 @@ class VideoResNet(nn.Module): ...@@ -296,8 +272,7 @@ class VideoResNet(nn.Module):
def _initialize_weights(self) -> None: def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv3d): if isinstance(m, nn.Conv3d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
nonlinearity='relu')
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d): elif isinstance(m, nn.BatchNorm3d):
...@@ -312,8 +287,7 @@ def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, ** ...@@ -312,8 +287,7 @@ def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, **
model = VideoResNet(**kwargs) model = VideoResNet(**kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
...@@ -330,12 +304,16 @@ def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi ...@@ -330,12 +304,16 @@ def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi
nn.Module: R3D-18 network nn.Module: R3D-18 network
""" """
return _video_resnet('r3d_18', return _video_resnet(
pretrained, progress, "r3d_18",
block=BasicBlock, pretrained,
conv_makers=[Conv3DSimple] * 4, progress,
layers=[2, 2, 2, 2], block=BasicBlock,
stem=BasicStem, **kwargs) conv_makers=[Conv3DSimple] * 4,
layers=[2, 2, 2, 2],
stem=BasicStem,
**kwargs,
)
def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
...@@ -349,12 +327,16 @@ def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi ...@@ -349,12 +327,16 @@ def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi
Returns: Returns:
nn.Module: MC3 Network definition nn.Module: MC3 Network definition
""" """
return _video_resnet('mc3_18', return _video_resnet(
pretrained, progress, "mc3_18",
block=BasicBlock, pretrained,
conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] progress,
layers=[2, 2, 2, 2], block=BasicBlock,
stem=BasicStem, **kwargs) conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
layers=[2, 2, 2, 2],
stem=BasicStem,
**kwargs,
)
def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet:
...@@ -368,9 +350,13 @@ def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -368,9 +350,13 @@ def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any)
Returns: Returns:
nn.Module: R(2+1)D-18 network nn.Module: R(2+1)D-18 network
""" """
return _video_resnet('r2plus1d_18', return _video_resnet(
pretrained, progress, "r2plus1d_18",
block=BasicBlock, pretrained,
conv_makers=[Conv2Plus1D] * 4, progress,
layers=[2, 2, 2, 2], block=BasicBlock,
stem=R2Plus1dStem, **kwargs) conv_makers=[Conv2Plus1D] * 4,
layers=[2, 2, 2, 2],
stem=R2Plus1dStem,
**kwargs,
)
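
All three video builders share the 18-layer layout and differ only in stem and conv makers; they expect clips as (N, C, T, H, W). A minimal sketch (not part of the diff; the default head has 400 Kinetics classes):

    import torch
    from torchvision.models.video import r3d_18

    model = r3d_18(pretrained=False).eval()
    clip = torch.rand(1, 3, 16, 112, 112)          # batch, channels, frames, height, width
    logits = model(clip)                           # shape [1, 400]
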
from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \ from ._register_onnx_ops import _register_custom_op
masks_to_boxes from .boxes import (
nms,
batched_nms,
remove_small_boxes,
clip_boxes_to_image,
box_area,
box_iou,
generalized_box_iou,
masks_to_boxes,
)
from .boxes import box_convert from .boxes import box_convert
from .deform_conv import deform_conv2d, DeformConv2d from .deform_conv import deform_conv2d, DeformConv2d
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss from .focal_loss import sigmoid_focal_loss
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .stochastic_depth import stochastic_depth, StochasticDepth from .stochastic_depth import stochastic_depth, StochasticDepth
from ._register_onnx_ops import _register_custom_op
_register_custom_op() _register_custom_op()
__all__ = [ __all__ = [
'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes', "deform_conv2d",
'clip_boxes_to_image', 'box_convert', "DeformConv2d",
'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', "nms",
'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', "batched_nms",
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', "remove_small_boxes",
'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth' "clip_boxes_to_image",
"box_convert",
"box_area",
"box_iou",
"generalized_box_iou",
"roi_align",
"RoIAlign",
"roi_pool",
"RoIPool",
"ps_roi_align",
"PSRoIAlign",
"ps_roi_pool",
"PSRoIPool",
"MultiScaleRoIAlign",
"FeaturePyramidNetwork",
"sigmoid_focal_loss",
"stochastic_depth",
"StochasticDepth",
] ]
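
Everything re-exported here is reachable as torchvision.ops.<name>; for instance nms and box_iou compose directly (example only, not part of the diff):

    import torch
    from torchvision.ops import box_iou, nms

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],
                          [50., 50., 60., 60.]])
    scores = torch.tensor([0.9, 0.8, 0.7])
    print(box_iou(boxes, boxes)[0, 1])             # ~0.68: the first two boxes overlap heavily
    print(nms(boxes, scores, iou_threshold=0.5))   # tensor([0, 2]): the lower-scoring overlap is dropped
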
import sys import sys
import torch
import warnings import warnings
import torch
_onnx_opset_version = 11 _onnx_opset_version = 11
def _register_custom_op(): def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx, scalar_type_to_pytorch_type, \ from torch.onnx.symbolic_helper import (
cast_pytorch_to_onnx parse_args,
from torch.onnx.symbolic_opset9 import _cast_Long scalar_type_to_onnx,
scalar_type_to_pytorch_type,
cast_pytorch_to_onnx,
)
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
from torch.onnx.symbolic_opset9 import _cast_Long
@parse_args('v', 'v', 'f') @parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = unsqueeze(g, boxes, 0) boxes = unsqueeze(g, boxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op('Constant', value_t=torch.tensor([sys.maxsize], dtype=torch.long)) max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float)) iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold) nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), 1) return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
@parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i') @parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant', batch_indices = _cast_Long(
value_t=torch.tensor([0], dtype=torch.long))), 1), False) g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) )
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
if aligned: if aligned:
warnings.warn("ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes," warnings.warn(
" ONNX forces ROIs to be 1x1 or larger.") "ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes,"
" ONNX forces ROIs to be 1x1 or larger."
)
scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float) scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float)
rois = g.op("Sub", rois, scale) rois = g.op("Sub", rois, scale)
# ONNX doesn't support negative sampling_ratio # ONNX doesn't support negative sampling_ratio
if sampling_ratio < 0: if sampling_ratio < 0:
warnings.warn("ONNX doesn't support negative sampling ratio," warnings.warn(
"therefore is is set to 0 in order to be exported.") "ONNX doesn't support negative sampling ratio," "therefore is is set to 0 in order to be exported."
)
sampling_ratio = 0 sampling_ratio = 0
return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale, return g.op(
output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio) "RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)
@parse_args('v', 'v', 'f', 'i', 'i') @parse_args("v", "v", "f", "i", "i")
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
roi_pool = g.op('MaxRoiPool', input, rois, roi_pool = g.op(
pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale) "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
)
return roi_pool, None return roi_pool, None
from torch.onnx import register_custom_op_symbolic from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version) register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version) register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
from typing import List, Union

import torch
from torch import Tensor


def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
...@@ -27,10 +28,11 @@ def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            assert (
                _tensor.size(1) == 4
            ), "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
    elif isinstance(boxes, torch.Tensor):
        assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"
    else:
        assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]"
    return
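check_roi_boxes_shape enforces the two box layouts the ROI ops accept. A small sketch of both layouts (not part of the diff; shapes and values are invented for illustration):

import torch
from torchvision.ops import roi_pool

feats = torch.rand(2, 8, 32, 32)

# Layout 1: a list with one Tensor[L, 4] of (x1, y1, x2, y2) boxes per image.
boxes_per_image = [torch.tensor([[0.0, 0.0, 10.0, 10.0]]), torch.tensor([[4.0, 4.0, 20.0, 20.0]])]
out1 = roi_pool(feats, boxes_per_image, output_size=(4, 4))

# Layout 2: a single Tensor[K, 5] whose first column is the batch index.
rois = torch.tensor([[0.0, 0.0, 0.0, 10.0, 10.0], [1.0, 4.0, 4.0, 20.0, 20.0]])
out2 = roi_pool(feats, rois, output_size=(4, 4))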
from typing import Tuple

import torch
import torchvision
from torch import Tensor
from torchvision.extension import _assert_has_ops

from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh


def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
...@@ -183,13 +185,13 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    if in_fmt == out_fmt:
        return boxes.clone()

    if in_fmt != "xyxy" and out_fmt != "xyxy":
        # convert to xyxy and change in_fmt to xyxy
        if in_fmt == "xywh":
            boxes = _box_xywh_to_xyxy(boxes)
        elif in_fmt == "cxcywh":
            boxes = _box_cxcywh_to_xyxy(boxes)
        in_fmt = "xyxy"

    if in_fmt == "xyxy":
        if out_fmt == "xywh":
...
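box_convert routes any conversion between two non-xyxy formats through xyxy, as the hunk above shows. A minimal usage sketch (not part of the diff; the coordinates are made up):

import torch
from torchvision.ops import box_convert

boxes_xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, width, height
# Neither format is xyxy, so the conversion goes xywh -> xyxy -> cxcywh internally.
boxes_cxcywh = box_convert(boxes_xywh, in_fmt="xywh", out_fmt="cxcywh")
print(boxes_cxcywh)  # tensor([[25., 40., 30., 40.]])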
import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
from torchvision.extension import _assert_has_ops
...@@ -84,7 +84,9 @@ def deform_conv2d(
"the shape of the offset tensor at dimension 1 is not valid. It should " "the shape of the offset tensor at dimension 1 is not valid. It should "
"be a multiple of 2 * weight.size[2] * weight.size[3].\n" "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
"Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format( "Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format(
offset.shape[1], 2 * weights_h * weights_w)) offset.shape[1], 2 * weights_h * weights_w
)
)
return torch.ops.torchvision.deform_conv2d( return torch.ops.torchvision.deform_conv2d(
input, input,
...@@ -92,12 +94,16 @@ def deform_conv2d( ...@@ -92,12 +94,16 @@ def deform_conv2d(
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )


class DeformConv2d(nn.Module):
...@@ -119,9 +125,9 @@ class DeformConv2d(nn.Module):
        super(DeformConv2d, self).__init__()

        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
...@@ -131,13 +137,14 @@ class DeformConv2d(nn.Module):
        self.dilation = _pair(dilation)
        self.groups = groups

        self.weight = Parameter(
            torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1])
        )

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()
...@@ -160,18 +167,26 @@ class DeformConv2d(nn.Module):
                out_height, out_width]): masks to be applied for each position in the
                convolution kernel.
        """
        return deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "{in_channels}"
        s += ", {out_channels}"
        s += ", kernel_size={kernel_size}"
        s += ", stride={stride}"
        s += ", padding={padding}" if self.padding != (0, 0) else ""
        s += ", dilation={dilation}" if self.dilation != (1, 1) else ""
        s += ", groups={groups}" if self.groups != 1 else ""
        s += ", bias=False" if self.bias is None else ""
        s += ")"
        return s.format(**self.__dict__)
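For context on the module reformatted above, a short DeformConv2d usage sketch (not part of the diff; sizes and the zero offset are chosen only to keep the example easy to check):

import torch
from torchvision.ops import DeformConv2d

conv = DeformConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.rand(1, 3, 16, 16)
# The offset tensor carries 2 * kernel_h * kernel_w channels per offset group,
# with the same spatial size as the output; zero offsets reduce to a plain 3x3 conv.
offset = torch.zeros(1, 2 * 3 * 3, 16, 16)
out = conv(x, offset)
print(out.shape)  # torch.Size([1, 8, 16, 16])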
from collections import OrderedDict
from typing import Tuple, List, Dict, Optional

import torch.nn.functional as F
from torch import nn, Tensor


class ExtraFPNBlock(nn.Module):
    """
...@@ -21,6 +20,7 @@ class ExtraFPNBlock(nn.Module):
            of the FPN
        names (List[str]): the extended set of names for the results
    """

    def forward(
        self,
        results: List[Tensor],
...@@ -67,6 +67,7 @@ class FeaturePyramidNetwork(nn.Module):
        >>> ('feat3', torch.Size([1, 5, 8, 8]))]
    """

    def __init__(
        self,
        in_channels_list: List[int],
...@@ -165,6 +166,7 @@ class LastLevelMaxPool(ExtraFPNBlock):
""" """
Applies a max_pool2d on top of the last feature map Applies a max_pool2d on top of the last feature map
""" """
def forward( def forward(
self, self,
x: List[Tensor], x: List[Tensor],
...@@ -180,6 +182,7 @@ class LastLevelP6P7(ExtraFPNBlock): ...@@ -180,6 +182,7 @@ class LastLevelP6P7(ExtraFPNBlock):
""" """
This module is used in RetinaNet to generate extra layers, P6 and P7. This module is used in RetinaNet to generate extra layers, P6 and P7.
""" """
def __init__(self, in_channels: int, out_channels: int): def __init__(self, in_channels: int, out_channels: int):
super(LastLevelP6P7, self).__init__() super(LastLevelP6P7, self).__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
......
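The extra-block classes above plug into FeaturePyramidNetwork; a minimal sketch along the lines of the class docstring example (not part of the diff; channel counts and feature names are invented):

from collections import OrderedDict

import torch
from torchvision.ops import FeaturePyramidNetwork
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool

fpn = FeaturePyramidNetwork([10, 20, 30], out_channels=5, extra_blocks=LastLevelMaxPool())
x = OrderedDict()
x["feat0"] = torch.rand(1, 10, 64, 64)
x["feat1"] = torch.rand(1, 20, 16, 16)
x["feat2"] = torch.rand(1, 30, 8, 8)
out = fpn(x)
# LastLevelMaxPool appends one extra, max-pooled level named "pool".
print([(k, tuple(v.shape)) for k, v in out.items()])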
...@@ -31,9 +31,7 @@ def sigmoid_focal_loss(
        Loss tensor with the reduction option applied.
    """
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
...
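The focal-loss hunk above only reflows the BCE call; for reference, a small usage sketch of sigmoid_focal_loss (not part of the diff; the logits and targets are made up):

import torch
from torchvision.ops import sigmoid_focal_loss

logits = torch.tensor([[2.0, -1.0], [0.1, 0.2]])   # raw scores, before sigmoid
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # binary labels per entry

# gamma down-weights easy examples; alpha balances positives vs. negatives.
loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
print(loss)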