"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "076052f12eef6c2c64be85ca9c89054167cc1f24"
Unverified Commit 732fc0bd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support for inception prototype model (#4821)

* Moving original builder at the bottom of the page to use proper typing.

* Adding multiweight support to inception.

* Update doc.
parent 00b963ac
...@@ -38,7 +38,7 @@ The weights of the Inception V3 model are ported from the original paper rather ...@@ -38,7 +38,7 @@ The weights of the Inception V3 model are ported from the original paper rather
Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command: Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:
``` ```
torchrun --nproc_per_node=8 train.py --model inception_v3 torchrun --nproc_per_node=8 train.py --model inception_v3\
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
``` ```
......
...@@ -26,43 +26,6 @@ InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Ten ...@@ -26,43 +26,6 @@ InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Ten
_InceptionOutputs = InceptionOutputs _InceptionOutputs = InceptionOutputs
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    r"""Build an Inception v3 network, the architecture introduced in
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    Inputs must be at least 75x75 pixels.

    .. note::
        **Important**: Unlike the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so make sure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*; when ``pretrained`` is True and this
            argument is omitted, it is set to True.
    """
    # Without pretrained weights the keyword arguments go to the constructor untouched.
    if not pretrained:
        return Inception3(**kwargs)

    # Only fill in transform_input when the caller did not choose a value.
    kwargs.setdefault("transform_input", True)

    # Remember the caller's choice, then force the auxiliary branch on while the
    # checkpoint is loaded (presumably the state dict contains AuxLogits
    # parameters -- confirm against the checkpoint).
    keep_aux_branch = kwargs.get("aux_logits", True)
    if "aux_logits" in kwargs:
        kwargs["aux_logits"] = True

    kwargs["init_weights"] = False  # we are loading weights from a pretrained model
    net = Inception3(**kwargs)
    checkpoint = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    net.load_state_dict(checkpoint)

    # Drop the auxiliary branch again if the caller asked for a model without it.
    if not keep_aux_branch:
        net.aux_logits = False
        net.AuxLogits = None
    return net
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__( def __init__(
self, self,
...@@ -442,3 +405,40 @@ class BasicConv2d(nn.Module): ...@@ -442,3 +405,40 @@ class BasicConv2d(nn.Module):
x = self.conv(x) x = self.conv(x)
x = self.bn(x) x = self.bn(x)
return F.relu(x, inplace=True) return F.relu(x, inplace=True)
def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3:
    r"""Build an Inception v3 network, the architecture introduced in
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    Inputs must be at least 75x75 pixels.

    .. note::
        **Important**: Unlike the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so make sure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*; when ``pretrained`` is True and this
            argument is omitted, it is set to True.
    """
    # Without pretrained weights the keyword arguments go to the constructor untouched.
    if not pretrained:
        return Inception3(**kwargs)

    # Only fill in transform_input when the caller did not choose a value.
    kwargs.setdefault("transform_input", True)

    # Remember the caller's choice, then force the auxiliary branch on while the
    # checkpoint is loaded (presumably the state dict contains AuxLogits
    # parameters -- confirm against the checkpoint).
    keep_aux_branch = kwargs.get("aux_logits", True)
    if "aux_logits" in kwargs:
        kwargs["aux_logits"] = True

    kwargs["init_weights"] = False  # we are loading weights from a pretrained model
    net = Inception3(**kwargs)
    checkpoint = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    net.load_state_dict(checkpoint)

    # Drop the auxiliary branch again if the caller asked for a model without it.
    if not keep_aux_branch:
        net.aux_logits = False
        net.AuxLogits = None
    return net
...@@ -2,6 +2,7 @@ from .alexnet import * ...@@ -2,6 +2,7 @@ from .alexnet import *
from .densenet import * from .densenet import *
from .efficientnet import * from .efficientnet import *
from .googlenet import * from .googlenet import *
from .inception import *
from .mnasnet import * from .mnasnet import *
from .mobilenetv2 import * from .mobilenetv2 import *
from .mobilenetv3 import * from .mobilenetv3 import *
......
import warnings
from functools import partial
from typing import Any, Optional
from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES
# Public names of this module: the model classes imported above plus the
# weights class and builder function defined below.
__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception3Weights", "inception_v3"]
# Metadata merged into a weight entry's ``meta`` dict via ``**_common_meta``.
_common_meta = {"size": (299, 299), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
class Inception3Weights(Weights):
    """Pre-trained weight entries that can be passed to ``inception_v3``."""

    # Entry backed by the "inception_v3_google" checkpoint file.
    ImageNet1K_TFV1 = WeightEntry(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        # Preprocessing preset: ImageNetEval with a 342 resize and a 299 crop.
        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
        meta={
            # Shared size / categories / interpolation metadata.
            **_common_meta,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
            # NOTE(review): presumably top-1 / top-5 ImageNet validation accuracy in percent -- confirm.
            "acc@1": 77.294,
            "acc@5": 93.450,
        },
    )
def inception_v3(weights: Optional[Inception3Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
    """Build an Inception3 model, optionally initialised from pre-trained weights.

    Args:
        weights (Inception3Weights, optional): Weight entry to load. ``None``
            builds the model without loading a checkpoint. The value is
            validated through ``Inception3Weights.verify``.
        progress (bool): Forwarded to ``weights.state_dict`` when a checkpoint
            is loaded.
        **kwargs: Forwarded to the ``Inception3`` constructor. The deprecated
            ``pretrained`` flag is still accepted here: a truthy value selects
            ``Inception3Weights.ImageNet1K_TFV1`` and a falsy value selects no
            weights, in both cases replacing any ``weights`` argument.

    When weights are loaded, ``transform_input`` is set to True unless given
    explicitly, ``init_weights`` is forced to False, and ``num_classes`` is
    overwritten with the number of categories recorded in the weights'
    metadata. The model is built with ``aux_logits=True`` for loading; if the
    caller passed a falsy ``aux_logits``, the auxiliary branch is removed
    again afterwards.

    Returns:
        Inception3: The constructed model.
    """
    if "pretrained" in kwargs:
        # stacklevel=2 attributes the warning to the caller's line rather than
        # to this library function, so users can find the call to update.
        warnings.warn(
            "The argument pretrained is deprecated, please use weights instead.",
            stacklevel=2,
        )
        weights = Inception3Weights.ImageNet1K_TFV1 if kwargs.pop("pretrained") else None
    weights = Inception3Weights.verify(weights)

    # Capture the caller's preference before it is overridden for loading.
    original_aux_logits = kwargs.get("aux_logits", True)
    if weights is not None:
        kwargs.setdefault("transform_input", True)
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
        kwargs["num_classes"] = len(weights.meta["categories"])

    model = Inception3(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))
        # Strip the auxiliary branch that was only enabled for loading.
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment