Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7b1b68d7
Unverified
Commit
7b1b68d7
authored
Oct 22, 2021
by
Joao Gomes
Committed by
GitHub
Oct 22, 2021
Browse files
Multi-weight support for MobileNetV3 prototype models (#4723)
* Adding multiweight support for mobilenetv3 prototype
parent
b280c318
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
3 deletions
+94
-3
torchvision/models/mobilenetv3.py
torchvision/models/mobilenetv3.py
+3
-3
torchvision/prototype/models/__init__.py
torchvision/prototype/models/__init__.py
+1
-0
torchvision/prototype/models/mobilenetv3.py
torchvision/prototype/models/mobilenetv3.py
+90
-0
No files found.
torchvision/models/mobilenetv3.py
View file @
7b1b68d7
...
@@ -281,7 +281,7 @@ def _mobilenet_v3_conf(
...
@@ -281,7 +281,7 @@ def _mobilenet_v3_conf(
return
inverted_residual_setting
,
last_channel
return
inverted_residual_setting
,
last_channel
def
_mobilenet_v3
_model
(
def
_mobilenet_v3
(
arch
:
str
,
arch
:
str
,
inverted_residual_setting
:
List
[
InvertedResidualConfig
],
inverted_residual_setting
:
List
[
InvertedResidualConfig
],
last_channel
:
int
,
last_channel
:
int
,
...
@@ -309,7 +309,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
...
@@ -309,7 +309,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
"""
"""
arch
=
"mobilenet_v3_large"
arch
=
"mobilenet_v3_large"
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
arch
,
**
kwargs
)
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
arch
,
**
kwargs
)
return
_mobilenet_v3
_model
(
arch
,
inverted_residual_setting
,
last_channel
,
pretrained
,
progress
,
**
kwargs
)
return
_mobilenet_v3
(
arch
,
inverted_residual_setting
,
last_channel
,
pretrained
,
progress
,
**
kwargs
)
def
mobilenet_v3_small
(
pretrained
:
bool
=
False
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MobileNetV3
:
def
mobilenet_v3_small
(
pretrained
:
bool
=
False
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MobileNetV3
:
...
@@ -323,4 +323,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
...
@@ -323,4 +323,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
"""
"""
arch
=
"mobilenet_v3_small"
arch
=
"mobilenet_v3_small"
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
arch
,
**
kwargs
)
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
arch
,
**
kwargs
)
return
_mobilenet_v3
_model
(
arch
,
inverted_residual_setting
,
last_channel
,
pretrained
,
progress
,
**
kwargs
)
return
_mobilenet_v3
(
arch
,
inverted_residual_setting
,
last_channel
,
pretrained
,
progress
,
**
kwargs
)
torchvision/prototype/models/__init__.py
View file @
7b1b68d7
...
@@ -3,5 +3,6 @@ from .resnet import *
...
@@ -3,5 +3,6 @@ from .resnet import *
from
.densenet
import
*
from
.densenet
import
*
from
.vgg
import
*
from
.vgg
import
*
from
.efficientnet
import
*
from
.efficientnet
import
*
from
.mobilenetv3
import
*
from
.
import
detection
from
.
import
detection
from
.
import
quantization
from
.
import
quantization
torchvision/prototype/models/mobilenetv3.py
0 → 100644
View file @
7b1b68d7
import
warnings
from
functools
import
partial
from
typing
import
Any
,
Optional
,
List
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.mobilenetv3
import
MobileNetV3
,
_mobilenet_v3_conf
,
InvertedResidualConfig
from
..transforms.presets
import
ImageNetEval
from
._api
import
Weights
,
WeightEntry
from
._meta
import
_IMAGENET_CATEGORIES
__all__
=
[
"MobileNetV3"
,
"MobileNetV3LargeWeights"
,
"MobileNetV3SmallWeights"
,
"mobilenet_v3_large"
,
"mobilenet_v3_small"
,
]
def
_mobilenet_v3
(
inverted_residual_setting
:
List
[
InvertedResidualConfig
],
last_channel
:
int
,
weights
:
Optional
[
Weights
],
progress
:
bool
,
**
kwargs
:
Any
,
)
->
MobileNetV3
:
if
weights
is
not
None
:
kwargs
[
"num_classes"
]
=
len
(
weights
.
meta
[
"categories"
])
model
=
MobileNetV3
(
inverted_residual_setting
,
last_channel
,
**
kwargs
)
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
state_dict
(
progress
=
progress
))
return
model
_common_meta
=
{
"size"
:
(
224
,
224
),
"categories"
:
_IMAGENET_CATEGORIES
,
"interpolation"
:
InterpolationMode
.
BILINEAR
}
class
MobileNetV3LargeWeights
(
Weights
):
ImageNet1K_RefV1
=
WeightEntry
(
url
=
"https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
224
),
meta
=
{
**
_common_meta
,
"recipe"
:
"https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small"
,
"acc@1"
:
74.042
,
"acc@5"
:
91.340
,
},
)
class
MobileNetV3SmallWeights
(
Weights
):
ImageNet1K_RefV1
=
WeightEntry
(
url
=
"https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
224
),
meta
=
{
**
_common_meta
,
"recipe"
:
"https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small"
,
"acc@1"
:
67.668
,
"acc@5"
:
87.402
,
},
)
def
mobilenet_v3_large
(
weights
:
Optional
[
MobileNetV3LargeWeights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MobileNetV3
:
if
"pretrained"
in
kwargs
:
warnings
.
warn
(
"The argument pretrained is deprecated, please use weights instead."
)
weights
=
MobileNetV3LargeWeights
.
ImageNet1K_RefV1
if
kwargs
.
pop
(
"pretrained"
)
else
None
weights
=
MobileNetV3LargeWeights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
"mobilenet_v3_large"
,
**
kwargs
)
return
_mobilenet_v3
(
inverted_residual_setting
,
last_channel
,
weights
,
progress
,
**
kwargs
)
def
mobilenet_v3_small
(
weights
:
Optional
[
MobileNetV3SmallWeights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MobileNetV3
:
if
"pretrained"
in
kwargs
:
warnings
.
warn
(
"The argument pretrained is deprecated, please use weights instead."
)
weights
=
MobileNetV3SmallWeights
.
ImageNet1K_RefV1
if
kwargs
.
pop
(
"pretrained"
)
else
None
weights
=
MobileNetV3SmallWeights
.
verify
(
weights
)
inverted_residual_setting
,
last_channel
=
_mobilenet_v3_conf
(
"mobilenet_v3_small"
,
**
kwargs
)
return
_mobilenet_v3
(
inverted_residual_setting
,
last_channel
,
weights
,
progress
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment