Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d8654bb0
Unverified
Commit
d8654bb0
authored
Mar 07, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Mar 07, 2022
Browse files
Refactor preset transforms (#5562)
* Refactor preset transforms * Making presets public.
parent
2b5ab1bc
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
132 additions
and
117 deletions
+132
-117
torchvision/prototype/models/mobilenetv3.py
torchvision/prototype/models/mobilenetv3.py
+4
-4
torchvision/prototype/models/optical_flow/raft.py
torchvision/prototype/models/optical_flow/raft.py
+9
-9
torchvision/prototype/models/quantization/googlenet.py
torchvision/prototype/models/quantization/googlenet.py
+2
-2
torchvision/prototype/models/quantization/inception.py
torchvision/prototype/models/quantization/inception.py
+2
-2
torchvision/prototype/models/quantization/mobilenetv2.py
torchvision/prototype/models/quantization/mobilenetv2.py
+2
-2
torchvision/prototype/models/quantization/mobilenetv3.py
torchvision/prototype/models/quantization/mobilenetv3.py
+2
-2
torchvision/prototype/models/quantization/resnet.py
torchvision/prototype/models/quantization/resnet.py
+6
-6
torchvision/prototype/models/quantization/shufflenetv2.py
torchvision/prototype/models/quantization/shufflenetv2.py
+3
-3
torchvision/prototype/models/regnet.py
torchvision/prototype/models/regnet.py
+29
-29
torchvision/prototype/models/resnet.py
torchvision/prototype/models/resnet.py
+17
-17
torchvision/prototype/models/segmentation/deeplabv3.py
torchvision/prototype/models/segmentation/deeplabv3.py
+4
-4
torchvision/prototype/models/segmentation/fcn.py
torchvision/prototype/models/segmentation/fcn.py
+3
-3
torchvision/prototype/models/segmentation/lraspp.py
torchvision/prototype/models/segmentation/lraspp.py
+2
-2
torchvision/prototype/models/shufflenetv2.py
torchvision/prototype/models/shufflenetv2.py
+3
-3
torchvision/prototype/models/squeezenet.py
torchvision/prototype/models/squeezenet.py
+3
-3
torchvision/prototype/models/vgg.py
torchvision/prototype/models/vgg.py
+13
-10
torchvision/prototype/models/video/resnet.py
torchvision/prototype/models/video/resnet.py
+4
-4
torchvision/prototype/models/vision_transformer.py
torchvision/prototype/models/vision_transformer.py
+5
-5
torchvision/prototype/transforms/__init__.py
torchvision/prototype/transforms/__init__.py
+7
-1
torchvision/prototype/transforms/_presets.py
torchvision/prototype/transforms/_presets.py
+12
-6
No files found.
torchvision/prototype/models/mobilenetv3.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
,
List
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.mobilenetv3
import
MobileNetV3
,
_mobilenet_v3_conf
,
InvertedResidualConfig
...
...
@@ -51,7 +51,7 @@ _COMMON_META = {
class
MobileNet_V3_Large_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
5483032
,
...
...
@@ -62,7 +62,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
5483032
,
...
...
@@ -77,7 +77,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
class
MobileNet_V3_Small_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
2542856
,
...
...
torchvision/prototype/models/optical_flow/raft.py
View file @
d8654bb0
...
...
@@ -4,7 +4,7 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from
torch.nn.modules.instancenorm
import
InstanceNorm2d
from
torchvision.models.optical_flow
import
RAFT
from
torchvision.models.optical_flow.raft
import
_raft
,
BottleneckBlock
,
ResidualBlock
from
torchvision.prototype.transforms
import
Raft
Eval
from
torchvision.prototype.transforms
import
OpticalFlow
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
.._api
import
WeightsEnum
...
...
@@ -33,7 +33,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_V1
=
Weights
(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url
=
"https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -48,7 +48,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_V2
=
Weights
(
# Chairs + Things
url
=
"https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -63,7 +63,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_SKHT_V1
=
Weights
(
# Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
url
=
"https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -78,7 +78,7 @@ class Raft_Large_Weights(WeightsEnum):
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
url
=
"https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -91,7 +91,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_SKHT_K_V1
=
Weights
(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
url
=
"https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -106,7 +106,7 @@ class Raft_Large_Weights(WeightsEnum):
# Same as CT_SKHT with extra fine-tuning on Kitti
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
url
=
"https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
5257536
,
...
...
@@ -122,7 +122,7 @@ class Raft_Small_Weights(WeightsEnum):
C_T_V1
=
Weights
(
# Chairs + Things, ported from original paper repo (raft-small.pth)
url
=
"https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
990162
,
...
...
@@ -136,7 +136,7 @@ class Raft_Small_Weights(WeightsEnum):
C_T_V2
=
Weights
(
# Chairs + Things
url
=
"https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth"
,
transforms
=
Raft
Eval
,
transforms
=
OpticalFlow
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
990162
,
...
...
torchvision/prototype/models/quantization/googlenet.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ import warnings
from
functools
import
partial
from
typing
import
Any
,
Optional
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.googlenet
import
(
...
...
@@ -26,7 +26,7 @@ __all__ = [
class
GoogLeNet_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"GoogLeNet"
,
...
...
torchvision/prototype/models/quantization/inception.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.inception
import
(
...
...
@@ -25,7 +25,7 @@ __all__ = [
class
Inception_V3_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
299
,
resize_size
=
342
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
299
,
resize_size
=
342
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"InceptionV3"
,
...
...
torchvision/prototype/models/quantization/mobilenetv2.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.mobilenetv2
import
(
...
...
@@ -26,7 +26,7 @@ __all__ = [
class
MobileNet_V2_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_QNNPACK_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"MobileNetV2"
,
...
...
torchvision/prototype/models/quantization/mobilenetv3.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ from functools import partial
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.mobilenetv3
import
(
...
...
@@ -59,7 +59,7 @@ def _mobilenet_v3_model(
class
MobileNet_V3_Large_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_QNNPACK_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"MobileNetV3"
,
...
...
torchvision/prototype/models/quantization/resnet.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
,
Type
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.resnet
import
(
...
...
@@ -68,7 +68,7 @@ _COMMON_META = {
class
ResNet18_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -85,7 +85,7 @@ class ResNet18_QuantizedWeights(WeightsEnum):
class
ResNet50_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -98,7 +98,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
)
IMAGENET1K_FBGEMM_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -115,7 +115,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
class
ResNeXt101_32X8D_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
@@ -128,7 +128,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
)
IMAGENET1K_FBGEMM_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
torchvision/prototype/models/quantization/shufflenetv2.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.quantization.shufflenetv2
import
(
...
...
@@ -67,7 +67,7 @@ _COMMON_META = {
class
ShuffleNet_V2_X0_5_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
1366792
,
...
...
@@ -82,7 +82,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
class
ShuffleNet_V2_X1_0_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
2278604
,
...
...
torchvision/prototype/models/regnet.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ from functools import partial
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.regnet
import
RegNet
,
BlockParams
...
...
@@ -77,7 +77,7 @@ def _regnet(
class
RegNet_Y_400MF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
4344144
,
...
...
@@ -88,7 +88,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
4344144
,
...
...
@@ -103,7 +103,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
class
RegNet_Y_800MF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
6432512
,
...
...
@@ -114,7 +114,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
6432512
,
...
...
@@ -129,7 +129,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
class
RegNet_Y_1_6GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
11202430
,
...
...
@@ -140,7 +140,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
11202430
,
...
...
@@ -155,7 +155,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
class
RegNet_Y_3_2GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
19436338
,
...
...
@@ -166,7 +166,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
19436338
,
...
...
@@ -181,7 +181,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
class
RegNet_Y_8GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
39381472
,
...
...
@@ -192,7 +192,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
39381472
,
...
...
@@ -207,7 +207,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
class
RegNet_Y_16GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
83590140
,
...
...
@@ -218,7 +218,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
83590140
,
...
...
@@ -233,7 +233,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
class
RegNet_Y_32GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
145046770
,
...
...
@@ -244,7 +244,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
145046770
,
...
...
@@ -264,7 +264,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
class
RegNet_X_400MF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
5495976
,
...
...
@@ -275,7 +275,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
5495976
,
...
...
@@ -290,7 +290,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
class
RegNet_X_800MF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
7259656
,
...
...
@@ -301,7 +301,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
7259656
,
...
...
@@ -316,7 +316,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
class
RegNet_X_1_6GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
9190136
,
...
...
@@ -327,7 +327,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
9190136
,
...
...
@@ -342,7 +342,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
class
RegNet_X_3_2GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
15296552
,
...
...
@@ -353,7 +353,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
15296552
,
...
...
@@ -368,7 +368,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
class
RegNet_X_8GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
39572648
,
...
...
@@ -379,7 +379,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
39572648
,
...
...
@@ -394,7 +394,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
class
RegNet_X_16GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
54278536
,
...
...
@@ -405,7 +405,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
54278536
,
...
...
@@ -420,7 +420,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
class
RegNet_X_32GF_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
107811560
,
...
...
@@ -431,7 +431,7 @@ class RegNet_X_32GF_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
107811560
,
...
...
torchvision/prototype/models/resnet.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
,
Type
,
Union
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.resnet
import
BasicBlock
,
Bottleneck
,
ResNet
...
...
@@ -63,7 +63,7 @@ _COMMON_META = {
class
ResNet18_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet18-f37072fd.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -80,7 +80,7 @@ class ResNet18_Weights(WeightsEnum):
class
ResNet34_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet34-b627a593.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -97,7 +97,7 @@ class ResNet34_Weights(WeightsEnum):
class
ResNet50_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet50-0676ba61.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -110,7 +110,7 @@ class ResNet50_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet50-11ad3fa6.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -127,7 +127,7 @@ class ResNet50_Weights(WeightsEnum):
class
ResNet101_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet101-63fe2227.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -140,7 +140,7 @@ class ResNet101_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet101-cd907fc2.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -157,7 +157,7 @@ class ResNet101_Weights(WeightsEnum):
class
ResNet152_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet152-394f9c45.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -170,7 +170,7 @@ class ResNet152_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/resnet152-f82ba261.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNet"
,
...
...
@@ -187,7 +187,7 @@ class ResNet152_Weights(WeightsEnum):
class
ResNeXt50_32X4D_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
@@ -200,7 +200,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
@@ -217,7 +217,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
class
ResNeXt101_32X8D_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
@@ -230,7 +230,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"ResNeXt"
,
...
...
@@ -247,7 +247,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
class
Wide_ResNet50_2_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"WideResNet"
,
...
...
@@ -260,7 +260,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"WideResNet"
,
...
...
@@ -277,7 +277,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
class
Wide_ResNet101_2_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"WideResNet"
,
...
...
@@ -290,7 +290,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"WideResNet"
,
...
...
torchvision/prototype/models/segmentation/deeplabv3.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Voc
Eval
from
torchvision.prototype.transforms
import
SemanticSegmentation
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.segmentation.deeplabv3
import
DeepLabV3
,
_deeplabv3_mobilenetv3
,
_deeplabv3_resnet
...
...
@@ -36,7 +36,7 @@ _COMMON_META = {
class
DeepLabV3_ResNet50_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
42004074
,
...
...
@@ -51,7 +51,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
class
DeepLabV3_ResNet101_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
60996202
,
...
...
@@ -66,7 +66,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
class
DeepLabV3_MobileNet_V3_Large_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
11029328
,
...
...
torchvision/prototype/models/segmentation/fcn.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Voc
Eval
from
torchvision.prototype.transforms
import
SemanticSegmentation
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.segmentation.fcn
import
FCN
,
_fcn_resnet
...
...
@@ -26,7 +26,7 @@ _COMMON_META = {
class
FCN_ResNet50_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
35322218
,
...
...
@@ -41,7 +41,7 @@ class FCN_ResNet50_Weights(WeightsEnum):
class
FCN_ResNet101_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
54314346
,
...
...
torchvision/prototype/models/segmentation/lraspp.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Voc
Eval
from
torchvision.prototype.transforms
import
SemanticSegmentation
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.segmentation.lraspp
import
LRASPP
,
_lraspp_mobilenetv3
...
...
@@ -17,7 +17,7 @@ __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_l
class
LRASPP_MobileNet_V3_Large_Weights
(
WeightsEnum
):
COCO_WITH_VOC_LABELS_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth"
,
transforms
=
partial
(
Voc
Eval
,
resize_size
=
520
),
transforms
=
partial
(
SemanticSegmentation
Eval
,
resize_size
=
520
),
meta
=
{
"task"
:
"image_semantic_segmentation"
,
"architecture"
:
"LRASPP"
,
...
...
torchvision/prototype/models/shufflenetv2.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.shufflenetv2
import
ShuffleNetV2
...
...
@@ -55,7 +55,7 @@ _COMMON_META = {
class
ShuffleNet_V2_X0_5_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
1366792
,
...
...
@@ -69,7 +69,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
class
ShuffleNet_V2_X1_0_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
2278604
,
...
...
torchvision/prototype/models/squeezenet.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.squeezenet
import
SqueezeNet
...
...
@@ -27,7 +27,7 @@ _COMMON_META = {
class
SqueezeNet1_0_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"min_size"
:
(
21
,
21
),
...
...
@@ -42,7 +42,7 @@ class SqueezeNet1_0_Weights(WeightsEnum):
class
SqueezeNet1_1_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"min_size"
:
(
17
,
17
),
...
...
torchvision/prototype/models/vgg.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.vgg
import
VGG
,
make_layers
,
cfgs
...
...
@@ -55,7 +55,7 @@ _COMMON_META = {
class
VGG11_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg11-8a719046.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
132863336
,
...
...
@@ -69,7 +69,7 @@ class VGG11_Weights(WeightsEnum):
class
VGG11_BN_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg11_bn-6002323d.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
132868840
,
...
...
@@ -83,7 +83,7 @@ class VGG11_BN_Weights(WeightsEnum):
class
VGG13_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg13-19584684.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
133047848
,
...
...
@@ -97,7 +97,7 @@ class VGG13_Weights(WeightsEnum):
class
VGG13_BN_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg13_bn-abd245e5.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
133053736
,
...
...
@@ -111,7 +111,7 @@ class VGG13_BN_Weights(WeightsEnum):
class
VGG16_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg16-397923af.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
138357544
,
...
...
@@ -125,7 +125,10 @@ class VGG16_Weights(WeightsEnum):
IMAGENET1K_FEATURES
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
224
,
mean
=
(
0.48235
,
0.45882
,
0.40784
),
std
=
(
1.0
/
255.0
,
1.0
/
255.0
,
1.0
/
255.0
)
ImageClassificationEval
,
crop_size
=
224
,
mean
=
(
0.48235
,
0.45882
,
0.40784
),
std
=
(
1.0
/
255.0
,
1.0
/
255.0
,
1.0
/
255.0
),
),
meta
=
{
**
_COMMON_META
,
...
...
@@ -142,7 +145,7 @@ class VGG16_Weights(WeightsEnum):
class
VGG16_BN_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg16_bn-6c64b313.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
138365992
,
...
...
@@ -156,7 +159,7 @@ class VGG16_BN_Weights(WeightsEnum):
class
VGG19_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
143667240
,
...
...
@@ -170,7 +173,7 @@ class VGG19_Weights(WeightsEnum):
class
VGG19_BN_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
143678248
,
...
...
torchvision/prototype/models/video/resnet.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ from functools import partial
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Type
,
Union
from
torch
import
nn
from
torchvision.prototype.transforms
import
Kinect400
Eval
from
torchvision.prototype.transforms
import
VideoClassification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.video.resnet
import
(
...
...
@@ -65,7 +65,7 @@ _COMMON_META = {
class
R3D_18_Weights
(
WeightsEnum
):
KINETICS400_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/r3d_18-b3b3357e.pth"
,
transforms
=
partial
(
Kinect400
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
transforms
=
partial
(
VideoClassification
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"R3D"
,
...
...
@@ -80,7 +80,7 @@ class R3D_18_Weights(WeightsEnum):
class
MC3_18_Weights
(
WeightsEnum
):
KINETICS400_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mc3_18-a90a0ba3.pth"
,
transforms
=
partial
(
Kinect400
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
transforms
=
partial
(
VideoClassification
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"MC3"
,
...
...
@@ -95,7 +95,7 @@ class MC3_18_Weights(WeightsEnum):
class
R2Plus1D_18_Weights
(
WeightsEnum
):
KINETICS400_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth"
,
transforms
=
partial
(
Kinect400
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
transforms
=
partial
(
VideoClassification
Eval
,
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
)),
meta
=
{
**
_COMMON_META
,
"architecture"
:
"R(2+1)D"
,
...
...
torchvision/prototype/models/vision_transformer.py
View file @
d8654bb0
...
...
@@ -5,7 +5,7 @@
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.vision_transformer
import
VisionTransformer
,
interpolate_embeddings
# noqa: F401
...
...
@@ -38,7 +38,7 @@ _COMMON_META = {
class
ViT_B_16_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vit_b_16-c867db91.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
86567656
,
...
...
@@ -55,7 +55,7 @@ class ViT_B_16_Weights(WeightsEnum):
class
ViT_B_32_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vit_b_32-d86f8d99.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
88224232
,
...
...
@@ -72,7 +72,7 @@ class ViT_B_32_Weights(WeightsEnum):
class
ViT_L_16_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vit_l_16-852ce7e3.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
242
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
242
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
304326632
,
...
...
@@ -89,7 +89,7 @@ class ViT_L_16_Weights(WeightsEnum):
class
ViT_L_32_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/vit_l_32-c7638314.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
306535400
,
...
...
torchvision/prototype/transforms/__init__.py
View file @
d8654bb0
...
...
@@ -10,5 +10,11 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from
._geometry
import
HorizontalFlip
,
Resize
,
CenterCrop
,
RandomResizedCrop
,
FiveCrop
,
TenCrop
,
BatchMultiCrop
from
._meta
import
ConvertBoundingBoxFormat
,
ConvertImageDtype
,
ConvertImageColorSpace
from
._misc
import
Identity
,
Normalize
,
ToDtype
,
Lambda
from
._presets
import
CocoEval
,
ImageNetEval
,
VocEval
,
Kinect400Eval
,
RaftEval
from
._presets
import
(
ObjectDetectionEval
,
ImageClassificationEval
,
SemanticSegmentationEval
,
VideoClassificationEval
,
OpticalFlowEval
,
)
from
._type_conversion
import
DecodeImage
,
LabelToOneHot
torchvision/prototype/transforms/_presets.py
View file @
d8654bb0
...
...
@@ -6,10 +6,16 @@ from torch import Tensor, nn
from
...transforms
import
functional
as
F
,
InterpolationMode
__all__
=
[
"CocoEval"
,
"ImageNetEval"
,
"Kinect400Eval"
,
"VocEval"
,
"RaftEval"
]
__all__
=
[
"ObjectDetectionEval"
,
"ImageClassificationEval"
,
"VideoClassificationEval"
,
"SemanticSegmentationEval"
,
"OpticalFlowEval"
,
]
class
Coco
Eval
(
nn
.
Module
):
class
ObjectDetection
Eval
(
nn
.
Module
):
def
forward
(
self
,
img
:
Tensor
,
target
:
Optional
[
Dict
[
str
,
Tensor
]]
=
None
)
->
Tuple
[
Tensor
,
Optional
[
Dict
[
str
,
Tensor
]]]:
...
...
@@ -18,7 +24,7 @@ class CocoEval(nn.Module):
return
F
.
convert_image_dtype
(
img
,
torch
.
float
),
target
class
Image
Net
Eval
(
nn
.
Module
):
class
Image
Classification
Eval
(
nn
.
Module
):
def
__init__
(
self
,
crop_size
:
int
,
...
...
@@ -44,7 +50,7 @@ class ImageNetEval(nn.Module):
return
img
class
Kinect400
Eval
(
nn
.
Module
):
class
VideoClassification
Eval
(
nn
.
Module
):
def
__init__
(
self
,
crop_size
:
Tuple
[
int
,
int
],
...
...
@@ -69,7 +75,7 @@ class Kinect400Eval(nn.Module):
return
vid
.
permute
(
1
,
0
,
2
,
3
)
# (T, C, H, W) => (C, T, H, W)
class
Voc
Eval
(
nn
.
Module
):
class
SemanticSegmentation
Eval
(
nn
.
Module
):
def
__init__
(
self
,
resize_size
:
int
,
...
...
@@ -99,7 +105,7 @@ class VocEval(nn.Module):
return
img
,
target
class
Raft
Eval
(
nn
.
Module
):
class
OpticalFlow
Eval
(
nn
.
Module
):
def
forward
(
self
,
img1
:
Tensor
,
img2
:
Tensor
,
flow
:
Optional
[
Tensor
],
valid_flow_mask
:
Optional
[
Tensor
]
)
->
Tuple
[
Tensor
,
Tensor
,
Optional
[
Tensor
],
Optional
[
Tensor
]]:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment