Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
683baf8e
Unverified
Commit
683baf8e
authored
May 03, 2023
by
Adam J. Stewart
Committed by
GitHub
May 03, 2023
Browse files
Check sha256 of weights (#7219)
Co-authored-by:
Nicolas Hug
<
nh.nicolas.hug@gmail.com
>
parent
8811c915
Changes
41
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
25 additions
and
25 deletions
+25
-25
torchvision/models/_api.py
torchvision/models/_api.py
+2
-2
torchvision/models/alexnet.py
torchvision/models/alexnet.py
+1
-1
torchvision/models/convnext.py
torchvision/models/convnext.py
+1
-1
torchvision/models/densenet.py
torchvision/models/densenet.py
+1
-1
torchvision/models/detection/faster_rcnn.py
torchvision/models/detection/faster_rcnn.py
+3
-3
torchvision/models/detection/fcos.py
torchvision/models/detection/fcos.py
+1
-1
torchvision/models/detection/keypoint_rcnn.py
torchvision/models/detection/keypoint_rcnn.py
+1
-1
torchvision/models/detection/mask_rcnn.py
torchvision/models/detection/mask_rcnn.py
+2
-2
torchvision/models/detection/retinanet.py
torchvision/models/detection/retinanet.py
+2
-2
torchvision/models/detection/ssd.py
torchvision/models/detection/ssd.py
+1
-1
torchvision/models/detection/ssdlite.py
torchvision/models/detection/ssdlite.py
+1
-1
torchvision/models/efficientnet.py
torchvision/models/efficientnet.py
+1
-1
torchvision/models/googlenet.py
torchvision/models/googlenet.py
+1
-1
torchvision/models/inception.py
torchvision/models/inception.py
+1
-1
torchvision/models/maxvit.py
torchvision/models/maxvit.py
+1
-1
torchvision/models/mnasnet.py
torchvision/models/mnasnet.py
+1
-1
torchvision/models/mobilenetv2.py
torchvision/models/mobilenetv2.py
+1
-1
torchvision/models/mobilenetv3.py
torchvision/models/mobilenetv3.py
+1
-1
torchvision/models/optical_flow/raft.py
torchvision/models/optical_flow/raft.py
+1
-1
torchvision/models/quantization/googlenet.py
torchvision/models/quantization/googlenet.py
+1
-1
No files found.
torchvision/models/_api.py
View file @
683baf8e
...
@@ -85,8 +85,8 @@ class WeightsEnum(Enum):
...
@@ -85,8 +85,8 @@ class WeightsEnum(Enum):
)
)
return
obj
return
obj
def
get_state_dict
(
self
,
progress
:
bool
)
->
Mapping
[
str
,
Any
]:
def
get_state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Mapping
[
str
,
Any
]:
return
load_state_dict_from_url
(
self
.
url
,
progress
=
progres
s
)
return
load_state_dict_from_url
(
self
.
url
,
*
args
,
**
kwarg
s
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
.
{
self
.
_name_
}
"
return
f
"
{
self
.
__class__
.
__name__
}
.
{
self
.
_name_
}
"
...
...
torchvision/models/alexnet.py
View file @
683baf8e
...
@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
...
@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
model
=
AlexNet
(
**
kwargs
)
model
=
AlexNet
(
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/convnext.py
View file @
683baf8e
...
@@ -189,7 +189,7 @@ def _convnext(
...
@@ -189,7 +189,7 @@ def _convnext(
model
=
ConvNeXt
(
block_setting
,
stochastic_depth_prob
=
stochastic_depth_prob
,
**
kwargs
)
model
=
ConvNeXt
(
block_setting
,
stochastic_depth_prob
=
stochastic_depth_prob
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/densenet.py
View file @
683baf8e
...
@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
...
@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
r
"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
r
"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
)
state_dict
=
weights
.
get_state_dict
(
progress
=
progress
)
state_dict
=
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
)
for
key
in
list
(
state_dict
.
keys
()):
for
key
in
list
(
state_dict
.
keys
()):
res
=
pattern
.
match
(
key
)
res
=
pattern
.
match
(
key
)
if
res
:
if
res
:
...
...
torchvision/models/detection/faster_rcnn.py
View file @
683baf8e
...
@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
...
@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
model
=
FasterRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
model
=
FasterRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
weights
==
FasterRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
if
weights
==
FasterRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
overwrite_eps
(
model
,
0.0
)
overwrite_eps
(
model
,
0.0
)
...
@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
...
@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
...
@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/detection/fcos.py
View file @
683baf8e
...
@@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
...
@@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
model
=
FCOS
(
backbone
,
num_classes
,
**
kwargs
)
model
=
FCOS
(
backbone
,
num_classes
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/detection/keypoint_rcnn.py
View file @
683baf8e
...
@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
...
@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
model
=
KeypointRCNN
(
backbone
,
num_classes
,
num_keypoints
=
num_keypoints
,
**
kwargs
)
model
=
KeypointRCNN
(
backbone
,
num_classes
,
num_keypoints
=
num_keypoints
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
weights
==
KeypointRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
if
weights
==
KeypointRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
overwrite_eps
(
model
,
0.0
)
overwrite_eps
(
model
,
0.0
)
...
...
torchvision/models/detection/mask_rcnn.py
View file @
683baf8e
...
@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
...
@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
model
=
MaskRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
model
=
MaskRCNN
(
backbone
,
num_classes
=
num_classes
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
weights
==
MaskRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
if
weights
==
MaskRCNN_ResNet50_FPN_Weights
.
COCO_V1
:
overwrite_eps
(
model
,
0.0
)
overwrite_eps
(
model
,
0.0
)
...
@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
...
@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/detection/retinanet.py
View file @
683baf8e
...
@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
...
@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
model
=
RetinaNet
(
backbone
,
num_classes
,
**
kwargs
)
model
=
RetinaNet
(
backbone
,
num_classes
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
weights
==
RetinaNet_ResNet50_FPN_Weights
.
COCO_V1
:
if
weights
==
RetinaNet_ResNet50_FPN_Weights
.
COCO_V1
:
overwrite_eps
(
model
,
0.0
)
overwrite_eps
(
model
,
0.0
)
...
@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
...
@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
model
=
RetinaNet
(
backbone
,
num_classes
,
anchor_generator
=
anchor_generator
,
head
=
head
,
**
kwargs
)
model
=
RetinaNet
(
backbone
,
num_classes
,
anchor_generator
=
anchor_generator
,
head
=
head
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/detection/ssd.py
View file @
683baf8e
...
@@ -677,6 +677,6 @@ def ssd300_vgg16(
...
@@ -677,6 +677,6 @@ def ssd300_vgg16(
model
=
SSD
(
backbone
,
anchor_generator
,
(
300
,
300
),
num_classes
,
**
kwargs
)
model
=
SSD
(
backbone
,
anchor_generator
,
(
300
,
300
),
num_classes
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/detection/ssdlite.py
View file @
683baf8e
...
@@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large(
...
@@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/efficientnet.py
View file @
683baf8e
...
@@ -357,7 +357,7 @@ def _efficientnet(
...
@@ -357,7 +357,7 @@ def _efficientnet(
model
=
EfficientNet
(
inverted_residual_setting
,
dropout
,
last_channel
=
last_channel
,
**
kwargs
)
model
=
EfficientNet
(
inverted_residual_setting
,
dropout
,
last_channel
=
last_channel
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/googlenet.py
View file @
683baf8e
...
@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
...
@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
model
=
GoogLeNet
(
**
kwargs
)
model
=
GoogLeNet
(
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
not
original_aux_logits
:
if
not
original_aux_logits
:
model
.
aux_logits
=
False
model
.
aux_logits
=
False
model
.
aux1
=
None
# type: ignore[assignment]
model
.
aux1
=
None
# type: ignore[assignment]
...
...
torchvision/models/inception.py
View file @
683baf8e
...
@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
...
@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
model
=
Inception3
(
**
kwargs
)
model
=
Inception3
(
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
not
original_aux_logits
:
if
not
original_aux_logits
:
model
.
aux_logits
=
False
model
.
aux_logits
=
False
model
.
AuxLogits
=
None
model
.
AuxLogits
=
None
...
...
torchvision/models/maxvit.py
View file @
683baf8e
...
@@ -763,7 +763,7 @@ def _maxvit(
...
@@ -763,7 +763,7 @@ def _maxvit(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/mnasnet.py
View file @
683baf8e
...
@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
...
@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
model
=
MNASNet
(
alpha
,
**
kwargs
)
model
=
MNASNet
(
alpha
,
**
kwargs
)
if
weights
:
if
weights
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/mobilenetv2.py
View file @
683baf8e
...
@@ -255,6 +255,6 @@ def mobilenet_v2(
...
@@ -255,6 +255,6 @@ def mobilenet_v2(
model
=
MobileNetV2
(
**
kwargs
)
model
=
MobileNetV2
(
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/mobilenetv3.py
View file @
683baf8e
...
@@ -282,7 +282,7 @@ def _mobilenet_v3(
...
@@ -282,7 +282,7 @@ def _mobilenet_v3(
model
=
MobileNetV3
(
inverted_residual_setting
,
last_channel
,
**
kwargs
)
model
=
MobileNetV3
(
inverted_residual_setting
,
last_channel
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/optical_flow/raft.py
View file @
683baf8e
...
@@ -818,7 +818,7 @@ def _raft(
...
@@ -818,7 +818,7 @@ def _raft(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
...
torchvision/models/quantization/googlenet.py
View file @
683baf8e
...
@@ -197,7 +197,7 @@ def googlenet(
...
@@ -197,7 +197,7 @@ def googlenet(
quantize_model
(
model
,
backend
)
quantize_model
(
model
,
backend
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
if
not
original_aux_logits
:
if
not
original_aux_logits
:
model
.
aux_logits
=
False
model
.
aux_logits
=
False
model
.
aux1
=
None
# type: ignore[assignment]
model
.
aux1
=
None
# type: ignore[assignment]
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment