You need to sign in or sign up before continuing.
Unverified Commit 683baf8e authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Check sha256 of weights (#7219)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 8811c915
......@@ -265,7 +265,7 @@ def inception_v3(
if quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
......
......@@ -149,6 +149,6 @@ def mobilenet_v2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -149,7 +149,7 @@ def _mobilenet_v3_model(
torch.ao.quantization.prepare_qat(model, inplace=True)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if quantize:
torch.ao.quantization.convert(model, inplace=True)
......
......@@ -144,7 +144,7 @@ def _resnet(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -397,7 +397,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -298,7 +298,7 @@ def _resnet(
model = ResNet(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -275,7 +275,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -331,7 +331,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -385,6 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -168,7 +168,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -227,6 +227,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -173,6 +173,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -178,7 +178,7 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -109,7 +109,7 @@ def _squeezenet(
model = SqueezeNet(version, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -639,7 +639,7 @@ def _swin_transformer(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -593,7 +593,7 @@ def _mvit(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -303,7 +303,7 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -214,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model = S3D(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -497,7 +497,7 @@ def _swin_transformer3d(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -332,7 +332,7 @@ def _vision_transformer(
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -1052,7 +1052,7 @@ def _crestereo(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment