Unverified Commit 683baf8e authored by Adam J. Stewart's avatar Adam J. Stewart Committed by GitHub
Browse files

Check sha256 of weights (#7219)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 8811c915
......@@ -265,7 +265,7 @@ def inception_v3(
if quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if not quantize and not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
......
......@@ -149,6 +149,6 @@ def mobilenet_v2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -149,7 +149,7 @@ def _mobilenet_v3_model(
torch.ao.quantization.prepare_qat(model, inplace=True)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if quantize:
torch.ao.quantization.convert(model, inplace=True)
......
......@@ -144,7 +144,7 @@ def _resnet(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model(model, backend)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -397,7 +397,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -298,7 +298,7 @@ def _resnet(
model = ResNet(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -275,7 +275,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -331,7 +331,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -385,6 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -168,7 +168,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -227,6 +227,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -173,6 +173,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -178,7 +178,7 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -109,7 +109,7 @@ def _squeezenet(
model = SqueezeNet(version, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -639,7 +639,7 @@ def _swin_transformer(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -593,7 +593,7 @@ def _mvit(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -303,7 +303,7 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -214,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model = S3D(**kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......@@ -497,7 +497,7 @@ def _swin_transformer3d(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -332,7 +332,7 @@ def _vision_transformer(
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
......@@ -1052,7 +1052,7 @@ def _crestereo(
)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment