Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d8654bb0
Unverified
Commit
d8654bb0
authored
Mar 07, 2022
by
Vasilis Vryniotis
Committed by
GitHub
Mar 07, 2022
Browse files
Refactor preset transforms (#5562)
* Refactor preset transforms * Making presets public.
parent
2b5ab1bc
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
75 additions
and
57 deletions
+75
-57
references/classification/train.py
references/classification/train.py
+1
-1
references/detection/train.py
references/detection/train.py
+1
-1
references/optical_flow/train.py
references/optical_flow/train.py
+1
-1
references/segmentation/train.py
references/segmentation/train.py
+1
-1
references/video_classification/train.py
references/video_classification/train.py
+1
-1
torchvision/prototype/models/alexnet.py
torchvision/prototype/models/alexnet.py
+2
-2
torchvision/prototype/models/convnext.py
torchvision/prototype/models/convnext.py
+5
-5
torchvision/prototype/models/densenet.py
torchvision/prototype/models/densenet.py
+5
-5
torchvision/prototype/models/detection/faster_rcnn.py
torchvision/prototype/models/detection/faster_rcnn.py
+4
-4
torchvision/prototype/models/detection/fcos.py
torchvision/prototype/models/detection/fcos.py
+2
-2
torchvision/prototype/models/detection/keypoint_rcnn.py
torchvision/prototype/models/detection/keypoint_rcnn.py
+3
-3
torchvision/prototype/models/detection/mask_rcnn.py
torchvision/prototype/models/detection/mask_rcnn.py
+2
-2
torchvision/prototype/models/detection/retinanet.py
torchvision/prototype/models/detection/retinanet.py
+2
-2
torchvision/prototype/models/detection/ssd.py
torchvision/prototype/models/detection/ssd.py
+2
-2
torchvision/prototype/models/detection/ssdlite.py
torchvision/prototype/models/detection/ssdlite.py
+2
-2
torchvision/prototype/models/efficientnet.py
torchvision/prototype/models/efficientnet.py
+31
-13
torchvision/prototype/models/googlenet.py
torchvision/prototype/models/googlenet.py
+2
-2
torchvision/prototype/models/inception.py
torchvision/prototype/models/inception.py
+2
-2
torchvision/prototype/models/mnasnet.py
torchvision/prototype/models/mnasnet.py
+3
-3
torchvision/prototype/models/mobilenetv2.py
torchvision/prototype/models/mobilenetv2.py
+3
-3
No files found.
references/classification/train.py
View file @
d8654bb0
...
...
@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
weights
=
prototype
.
models
.
get_weight
(
args
.
weights
)
preprocessing
=
weights
.
transforms
()
else
:
preprocessing
=
prototype
.
transforms
.
Image
Net
Eval
(
preprocessing
=
prototype
.
transforms
.
Image
Classification
Eval
(
crop_size
=
val_crop_size
,
resize_size
=
val_resize_size
,
interpolation
=
interpolation
)
...
...
references/detection/train.py
View file @
d8654bb0
...
...
@@ -57,7 +57,7 @@ def get_transform(train, args):
weights
=
prototype
.
models
.
get_weight
(
args
.
weights
)
return
weights
.
transforms
()
else
:
return
prototype
.
transforms
.
Coco
Eval
()
return
prototype
.
transforms
.
ObjectDetection
Eval
()
def
get_args_parser
(
add_help
=
True
):
...
...
references/optical_flow/train.py
View file @
d8654bb0
...
...
@@ -137,7 +137,7 @@ def validate(model, args):
weights
=
prototype
.
models
.
get_weight
(
args
.
weights
)
preprocessing
=
weights
.
transforms
()
else
:
preprocessing
=
prototype
.
transforms
.
Raft
Eval
()
preprocessing
=
prototype
.
transforms
.
OpticalFlow
Eval
()
else
:
preprocessing
=
OpticalFlowPresetEval
()
...
...
references/segmentation/train.py
View file @
d8654bb0
...
...
@@ -42,7 +42,7 @@ def get_transform(train, args):
weights
=
prototype
.
models
.
get_weight
(
args
.
weights
)
return
weights
.
transforms
()
else
:
return
prototype
.
transforms
.
Voc
Eval
(
resize_size
=
520
)
return
prototype
.
transforms
.
SemanticSegmentation
Eval
(
resize_size
=
520
)
def
criterion
(
inputs
,
target
):
...
...
references/video_classification/train.py
View file @
d8654bb0
...
...
@@ -157,7 +157,7 @@ def main(args):
weights
=
prototype
.
models
.
get_weight
(
args
.
weights
)
transform_test
=
weights
.
transforms
()
else
:
transform_test
=
prototype
.
transforms
.
Kinect400
Eval
(
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
))
transform_test
=
prototype
.
transforms
.
VideoClassification
Eval
(
crop_size
=
(
112
,
112
),
resize_size
=
(
128
,
171
))
if
args
.
cache_dataset
and
os
.
path
.
exists
(
cache_path
):
print
(
f
"Loading dataset_test from
{
cache_path
}
"
)
...
...
torchvision/prototype/models/alexnet.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.alexnet
import
AlexNet
...
...
@@ -16,7 +16,7 @@ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class
AlexNet_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/alexnet-owt-7be5be79.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"AlexNet"
,
...
...
torchvision/prototype/models/convnext.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.convnext
import
ConvNeXt
,
CNBlockConfig
...
...
@@ -56,7 +56,7 @@ _COMMON_META = {
class
ConvNeXt_Tiny_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/convnext_tiny-983f1562.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
236
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
236
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
28589128
,
...
...
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
class
ConvNeXt_Small_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/convnext_small-0c510722.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
230
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
230
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
50223688
,
...
...
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
class
ConvNeXt_Base_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/convnext_base-6075fbad.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
88591464
,
...
...
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
class
ConvNeXt_Large_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/convnext_large-ea097f82.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
197767336
,
...
...
torchvision/prototype/models/densenet.py
View file @
d8654bb0
...
...
@@ -3,7 +3,7 @@ from functools import partial
from
typing
import
Any
,
Optional
,
Tuple
import
torch.nn
as
nn
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.densenet
import
DenseNet
...
...
@@ -78,7 +78,7 @@ _COMMON_META = {
class
DenseNet121_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/densenet121-a639ec97.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
7978856
,
...
...
@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
class
DenseNet161_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/densenet161-8d451a50.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
28681000
,
...
...
@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
class
DenseNet169_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/densenet169-b2777c0a.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
14149480
,
...
...
@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
class
DenseNet201_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/densenet201-c1103571.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
20013928
,
...
...
torchvision/prototype/models/detection/faster_rcnn.py
View file @
d8654bb0
from
typing
import
Any
,
Optional
,
Union
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.faster_rcnn
import
(
...
...
@@ -43,7 +43,7 @@ _COMMON_META = {
class
FasterRCNN_ResNet50_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
41755286
,
...
...
@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
class
FasterRCNN_MobileNet_V3_Large_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
19386354
,
...
...
@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
class
FasterRCNN_MobileNet_V3_Large_320_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
19386354
,
...
...
torchvision/prototype/models/detection/fcos.py
View file @
d8654bb0
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.fcos
import
(
...
...
@@ -27,7 +27,7 @@ __all__ = [
class
FCOS_ResNet50_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
"task"
:
"image_object_detection"
,
"architecture"
:
"FCOS"
,
...
...
torchvision/prototype/models/detection/keypoint_rcnn.py
View file @
d8654bb0
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.keypoint_rcnn
import
(
...
...
@@ -37,7 +37,7 @@ _COMMON_META = {
class
KeypointRCNN_ResNet50_FPN_Weights
(
WeightsEnum
):
COCO_LEGACY
=
Weights
(
url
=
"https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
59137258
,
...
...
@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
)
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
**
_COMMON_META
,
"num_params"
:
59137258
,
...
...
torchvision/prototype/models/detection/mask_rcnn.py
View file @
d8654bb0
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.mask_rcnn
import
(
...
...
@@ -27,7 +27,7 @@ __all__ = [
class
MaskRCNN_ResNet50_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
"task"
:
"image_object_detection"
,
"architecture"
:
"MaskRCNN"
,
...
...
torchvision/prototype/models/detection/retinanet.py
View file @
d8654bb0
from
typing
import
Any
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.retinanet
import
(
...
...
@@ -28,7 +28,7 @@ __all__ = [
class
RetinaNet_ResNet50_FPN_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
"task"
:
"image_object_detection"
,
"architecture"
:
"RetinaNet"
,
...
...
torchvision/prototype/models/detection/ssd.py
View file @
d8654bb0
import
warnings
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.ssd
import
(
...
...
@@ -25,7 +25,7 @@ __all__ = [
class
SSD300_VGG16_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
"task"
:
"image_object_detection"
,
"architecture"
:
"SSD"
,
...
...
torchvision/prototype/models/detection/ssdlite.py
View file @
d8654bb0
...
...
@@ -3,7 +3,7 @@ from functools import partial
from
typing
import
Any
,
Callable
,
Optional
from
torch
import
nn
from
torchvision.prototype.transforms
import
Coco
Eval
from
torchvision.prototype.transforms
import
ObjectDetection
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
....models.detection.ssdlite
import
(
...
...
@@ -30,7 +30,7 @@ __all__ = [
class
SSDLite320_MobileNet_V3_Large_Weights
(
WeightsEnum
):
COCO_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth"
,
transforms
=
Coco
Eval
,
transforms
=
ObjectDetection
Eval
,
meta
=
{
"task"
:
"image_object_detection"
,
"architecture"
:
"SSDLite"
,
...
...
torchvision/prototype/models/efficientnet.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ from functools import partial
from
typing
import
Any
,
Optional
,
Sequence
,
Union
from
torch
import
nn
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.efficientnet
import
EfficientNet
,
MBConvConfig
,
FusedMBConvConfig
,
_efficientnet_conf
...
...
@@ -85,7 +85,9 @@ _COMMON_META_V2 = {
class
EfficientNet_B0_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
224
,
resize_size
=
256
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
224
,
resize_size
=
256
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
5288548
,
...
...
@@ -100,7 +102,9 @@ class EfficientNet_B0_Weights(WeightsEnum):
class
EfficientNet_B1_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
240
,
resize_size
=
256
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
240
,
resize_size
=
256
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
7794184
,
...
...
@@ -111,7 +115,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
240
,
resize_size
=
255
,
interpolation
=
InterpolationMode
.
BILINEAR
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
240
,
resize_size
=
255
,
interpolation
=
InterpolationMode
.
BILINEAR
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
7794184
,
...
...
@@ -128,7 +134,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
class
EfficientNet_B2_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
288
,
resize_size
=
288
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
288
,
resize_size
=
288
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
9109994
,
...
...
@@ -143,7 +151,9 @@ class EfficientNet_B2_Weights(WeightsEnum):
class
EfficientNet_B3_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
300
,
resize_size
=
320
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
300
,
resize_size
=
320
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
12233232
,
...
...
@@ -158,7 +168,9 @@ class EfficientNet_B3_Weights(WeightsEnum):
class
EfficientNet_B4_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
380
,
resize_size
=
384
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
380
,
resize_size
=
384
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
19341616
,
...
...
@@ -173,7 +185,9 @@ class EfficientNet_B4_Weights(WeightsEnum):
class
EfficientNet_B5_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
456
,
resize_size
=
456
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
456
,
resize_size
=
456
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
30389784
,
...
...
@@ -188,7 +202,9 @@ class EfficientNet_B5_Weights(WeightsEnum):
class
EfficientNet_B6_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
528
,
resize_size
=
528
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
528
,
resize_size
=
528
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
43040704
,
...
...
@@ -203,7 +219,9 @@ class EfficientNet_B6_Weights(WeightsEnum):
class
EfficientNet_B7_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth"
,
transforms
=
partial
(
ImageNetEval
,
crop_size
=
600
,
resize_size
=
600
,
interpolation
=
InterpolationMode
.
BICUBIC
),
transforms
=
partial
(
ImageClassificationEval
,
crop_size
=
600
,
resize_size
=
600
,
interpolation
=
InterpolationMode
.
BICUBIC
),
meta
=
{
**
_COMMON_META_V1
,
"num_params"
:
66347960
,
...
...
@@ -219,7 +237,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
Image
Classification
Eval
,
crop_size
=
384
,
resize_size
=
384
,
interpolation
=
InterpolationMode
.
BILINEAR
,
...
...
@@ -239,7 +257,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
Image
Classification
Eval
,
crop_size
=
480
,
resize_size
=
480
,
interpolation
=
InterpolationMode
.
BILINEAR
,
...
...
@@ -259,7 +277,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
Image
Classification
Eval
,
crop_size
=
480
,
resize_size
=
480
,
interpolation
=
InterpolationMode
.
BICUBIC
,
...
...
torchvision/prototype/models/googlenet.py
View file @
d8654bb0
...
...
@@ -2,7 +2,7 @@ import warnings
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.googlenet
import
GoogLeNet
,
GoogLeNetOutputs
,
_GoogLeNetOutputs
...
...
@@ -17,7 +17,7 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weig
class
GoogLeNet_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/googlenet-1378be20.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"GoogLeNet"
,
...
...
torchvision/prototype/models/inception.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.inception
import
Inception3
,
InceptionOutputs
,
_InceptionOutputs
...
...
@@ -16,7 +16,7 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_
class
Inception_V3_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
299
,
resize_size
=
342
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
299
,
resize_size
=
342
),
meta
=
{
"task"
:
"image_classification"
,
"architecture"
:
"InceptionV3"
,
...
...
torchvision/prototype/models/mnasnet.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.mnasnet
import
MNASNet
...
...
@@ -38,7 +38,7 @@ _COMMON_META = {
class
MNASNet0_5_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
2218512
,
...
...
@@ -57,7 +57,7 @@ class MNASNet0_75_Weights(WeightsEnum):
class
MNASNet1_0_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"num_params"
:
4383312
,
...
...
torchvision/prototype/models/mobilenetv2.py
View file @
d8654bb0
from
functools
import
partial
from
typing
import
Any
,
Optional
from
torchvision.prototype.transforms
import
Image
Net
Eval
from
torchvision.prototype.transforms
import
Image
Classification
Eval
from
torchvision.transforms.functional
import
InterpolationMode
from
...models.mobilenetv2
import
MobileNetV2
...
...
@@ -28,7 +28,7 @@ _COMMON_META = {
class
MobileNet_V2_Weights
(
WeightsEnum
):
IMAGENET1K_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/mobilenet_v2-b0353104.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
),
meta
=
{
**
_COMMON_META
,
"recipe"
:
"https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2"
,
...
...
@@ -38,7 +38,7 @@ class MobileNet_V2_Weights(WeightsEnum):
)
IMAGENET1K_V2
=
Weights
(
url
=
"https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth"
,
transforms
=
partial
(
Image
Net
Eval
,
crop_size
=
224
,
resize_size
=
232
),
transforms
=
partial
(
Image
Classification
Eval
,
crop_size
=
224
,
resize_size
=
232
),
meta
=
{
**
_COMMON_META
,
"recipe"
:
"https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning"
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment