".github/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "7dca70049566b5b1c55cbd67e1cb191729a98152"
Unverified Commit 683baf8e authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Check sha256 of weights (#7219)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 8811c915
...@@ -85,8 +85,8 @@ class WeightsEnum(Enum): ...@@ -85,8 +85,8 @@ class WeightsEnum(Enum):
) )
return obj return obj
def get_state_dict(self, progress: bool) -> Mapping[str, Any]: def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
return load_state_dict_from_url(self.url, progress=progress) return load_state_dict_from_url(self.url, *args, **kwargs)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}" return f"{self.__class__.__name__}.{self._name_}"
......
...@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, ...@@ -114,6 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
model = AlexNet(**kwargs) model = AlexNet(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -189,7 +189,7 @@ def _convnext( ...@@ -189,7 +189,7 @@ def _convnext(
model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> ...@@ -227,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
) )
state_dict = weights.get_state_dict(progress=progress) state_dict = weights.get_state_dict(progress=progress, check_hash=True)
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
res = pattern.match(key) res = pattern.match(key)
if res: if res:
......
...@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -571,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
...@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2( ...@@ -653,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -694,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -766,6 +766,6 @@ def fcos_resnet50_fpn( ...@@ -766,6 +766,6 @@ def fcos_resnet50_fpn(
model = FCOS(backbone, num_classes, **kwargs) model = FCOS(backbone, num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn( ...@@ -465,7 +465,7 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
......
...@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn( ...@@ -501,7 +501,7 @@ def maskrcnn_resnet50_fpn(
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
...@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2( ...@@ -582,6 +582,6 @@ def maskrcnn_resnet50_fpn_v2(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn( ...@@ -815,7 +815,7 @@ def retinanet_resnet50_fpn(
model = RetinaNet(backbone, num_classes, **kwargs) model = RetinaNet(backbone, num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
...@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2( ...@@ -894,6 +894,6 @@ def retinanet_resnet50_fpn_v2(
model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs) model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -677,6 +677,6 @@ def ssd300_vgg16( ...@@ -677,6 +677,6 @@ def ssd300_vgg16(
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large( ...@@ -326,6 +326,6 @@ def ssdlite320_mobilenet_v3_large(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -357,7 +357,7 @@ def _efficientnet( ...@@ -357,7 +357,7 @@ def _efficientnet(
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T ...@@ -332,7 +332,7 @@ def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = T
model = GoogLeNet(**kwargs) model = GoogLeNet(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.aux1 = None # type: ignore[assignment] model.aux1 = None # type: ignore[assignment]
......
...@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo ...@@ -470,7 +470,7 @@ def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bo
model = Inception3(**kwargs) model = Inception3(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
......
...@@ -763,7 +763,7 @@ def _maxvit( ...@@ -763,7 +763,7 @@ def _maxvit(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa ...@@ -317,7 +317,7 @@ def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwa
model = MNASNet(alpha, **kwargs) model = MNASNet(alpha, **kwargs)
if weights: if weights:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -255,6 +255,6 @@ def mobilenet_v2( ...@@ -255,6 +255,6 @@ def mobilenet_v2(
model = MobileNetV2(**kwargs) model = MobileNetV2(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -282,7 +282,7 @@ def _mobilenet_v3( ...@@ -282,7 +282,7 @@ def _mobilenet_v3(
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -818,7 +818,7 @@ def _raft( ...@@ -818,7 +818,7 @@ def _raft(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -197,7 +197,7 @@ def googlenet( ...@@ -197,7 +197,7 @@ def googlenet(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.aux1 = None # type: ignore[assignment] model.aux1 = None # type: ignore[assignment]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment