Unverified Commit 72054ca0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support for googlenet prototype model (#4813)

* Move model builder at the bottom of the file, so we can use proper typing.

* Adding GoogLeNet with multi-weight support.

* Simplify expression.
parent 61399976
...@@ -25,43 +25,6 @@ GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Te ...@@ -25,43 +25,6 @@ GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Te
_GoogLeNetOutputs = GoogLeNetOutputs _GoogLeNetOutputs = GoogLeNetOutputs
def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
The required minimum input size of the model is 15x15.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = GoogLeNet(**kwargs)
state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model
return GoogLeNet(**kwargs)
class GoogLeNet(nn.Module): class GoogLeNet(nn.Module):
__constants__ = ["aux_logits", "transform_input"] __constants__ = ["aux_logits", "transform_input"]
...@@ -311,3 +274,40 @@ class BasicConv2d(nn.Module): ...@@ -311,3 +274,40 @@ class BasicConv2d(nn.Module):
x = self.conv(x) x = self.conv(x)
x = self.bn(x) x = self.bn(x)
return F.relu(x, inplace=True) return F.relu(x, inplace=True)
def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet:
r"""GoogLeNet (Inception v1) model architecture from
`"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
The required minimum input size of the model is 15x15.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
aux_logits (bool): If True, adds two auxiliary branches that can improve training.
Default: *False* when pretrained is True otherwise *True*
transform_input (bool): If True, preprocesses the input according to the method with which it
was trained on ImageNet. Default: *False*
"""
if pretrained:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = GoogLeNet(**kwargs)
state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model
return GoogLeNet(**kwargs)
from .alexnet import * from .alexnet import *
from .resnet import *
from .densenet import * from .densenet import *
from .vgg import *
from .efficientnet import * from .efficientnet import *
from .mobilenetv3 import * from .googlenet import *
from .mobilenetv2 import *
from .mnasnet import * from .mnasnet import *
from .mobilenetv2 import *
from .mobilenetv3 import *
from .regnet import * from .regnet import *
from .resnet import *
from .shufflenetv2 import * from .shufflenetv2 import *
from .vgg import *
from . import detection from . import detection
from . import quantization from . import quantization
from . import segmentation from . import segmentation
......
import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNetWeights", "googlenet"]
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class GoogLeNetWeights(Weights):
ImageNet1K_TheCodezV1 = WeightEntry(
url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "https://github.com/TheCodez/examples/blob/inception/imagenet/README.md#googlenet",
"acc@1": 69.778,
"acc@5": 89.530,
},
)
def googlenet(weights: Optional[GoogLeNetWeights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = GoogLeNetWeights.ImageNet1K_TheCodezV1 if kwargs.pop("pretrained") else None
weights = GoogLeNetWeights.verify(weights)
original_aux_logits = kwargs.get("aux_logits", False)
if weights is not None:
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if original_aux_logits:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
)
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
kwargs["num_classes"] = len(weights.meta["categories"])
model = GoogLeNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
model.aux2 = None # type: ignore[assignment]
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment