Unverified Commit 683baf8e authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Check sha256 of weights (#7219)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 8811c915
...@@ -265,7 +265,7 @@ def inception_v3( ...@@ -265,7 +265,7 @@ def inception_v3(
if quantize and not original_aux_logits: if quantize and not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not quantize and not original_aux_logits: if not quantize and not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
......
...@@ -149,6 +149,6 @@ def mobilenet_v2( ...@@ -149,6 +149,6 @@ def mobilenet_v2(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -149,7 +149,7 @@ def _mobilenet_v3_model( ...@@ -149,7 +149,7 @@ def _mobilenet_v3_model(
torch.ao.quantization.prepare_qat(model, inplace=True) torch.ao.quantization.prepare_qat(model, inplace=True)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if quantize: if quantize:
torch.ao.quantization.convert(model, inplace=True) torch.ao.quantization.convert(model, inplace=True)
......
...@@ -144,7 +144,7 @@ def _resnet( ...@@ -144,7 +144,7 @@ def _resnet(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -108,7 +108,7 @@ def _shufflenetv2( ...@@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -397,7 +397,7 @@ def _regnet( ...@@ -397,7 +397,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs) model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -298,7 +298,7 @@ def _resnet( ...@@ -298,7 +298,7 @@ def _resnet(
model = ResNet(block, layers, **kwargs) model = ResNet(block, layers, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -275,7 +275,7 @@ def deeplabv3_resnet50( ...@@ -275,7 +275,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -331,7 +331,7 @@ def deeplabv3_resnet101( ...@@ -331,7 +331,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -385,6 +385,6 @@ def deeplabv3_mobilenet_v3_large( ...@@ -385,6 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -168,7 +168,7 @@ def fcn_resnet50( ...@@ -168,7 +168,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -227,6 +227,6 @@ def fcn_resnet101( ...@@ -227,6 +227,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -173,6 +173,6 @@ def lraspp_mobilenet_v3_large( ...@@ -173,6 +173,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes) model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -178,7 +178,7 @@ def _shufflenetv2( ...@@ -178,7 +178,7 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -109,7 +109,7 @@ def _squeezenet( ...@@ -109,7 +109,7 @@ def _squeezenet(
model = SqueezeNet(version, **kwargs) model = SqueezeNet(version, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -639,7 +639,7 @@ def _swin_transformer( ...@@ -639,7 +639,7 @@ def _swin_transformer(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b ...@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -593,7 +593,7 @@ def _mvit( ...@@ -593,7 +593,7 @@ def _mvit(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -303,7 +303,7 @@ def _video_resnet( ...@@ -303,7 +303,7 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs) model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -214,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg ...@@ -214,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model = S3D(**kwargs) model = S3D(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -497,7 +497,7 @@ def _swin_transformer3d( ...@@ -497,7 +497,7 @@ def _swin_transformer3d(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -332,7 +332,7 @@ def _vision_transformer( ...@@ -332,7 +332,7 @@ def _vision_transformer(
) )
if weights: if weights:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
...@@ -1052,7 +1052,7 @@ def _crestereo( ...@@ -1052,7 +1052,7 @@ def _crestereo(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment