Unverified Commit 683baf8e authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Check sha256 of weights (#7219)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 8811c915
......@@ -85,8 +85,8 @@ class WeightsEnum(Enum):
)
return obj
def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, progress=progress)
def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, *args, **kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}"
......
......@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
model = AlexNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -189,7 +189,7 @@ def _convnext(
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
state_dict = weights.get_state_dict(progress=progress)
state_dict = weights.get_state_dict(progress=progress, check_hash=True)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
......
......@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
......@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
model = FCOS(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
......
......@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
......@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0)
......@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -677,6 +677,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -357,7 +357,7 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
model = GoogLeNet(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
......
......@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
model = Inception3(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
......
......@@ -763,7 +763,7 @@ def _maxvit(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
model = MNASNet(alpha, **kwargs)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -255,6 +255,6 @@ def mobilenet_v2(
model = MobileNetV2(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -282,7 +282,7 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -818,7 +818,7 @@ def _raft(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -197,7 +197,7 @@ def googlenet(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits:
model.aux_logits = False
model.aux1 = None # type: ignore[assignment]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment