Unverified Commit c35d3855 authored by Sergii Dymchenko, committed by GitHub

[TorchFix] Add weights_only to torch.load (#8105)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent 01dca0eb
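
Why the commit splits into weights_only=True and weights_only=False call sites: torch.load with weights_only=False goes through Python's full pickle machinery, which can execute arbitrary code embedded in an untrusted file, whereas weights_only=True uses a restricted unpickler that only reconstructs tensors, state dicts, and a small allowlist of container types. A minimal sketch of that distinction follows; the file names and objects are illustrative, not taken from this repository.

import torch

# Tensors and state dicts fall inside the weights-only allowlist, so the
# checkpoint-resume call sites in this commit can opt in to the safer mode.
torch.save({"model": {"w": torch.zeros(2, 2)}, "epoch": 3}, "checkpoint.pth")
checkpoint = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)

# Cached datasets and other arbitrary Python objects are not reconstructible by
# the restricted unpickler (torch.load typically raises an UnpicklingError), so
# those call sites keep weights_only=False with a TODO; only do this for files
# from a trusted source.
torch.save(("cached", {"dataset": object()}), "cache.pth")
cached = torch.load("cache.pth", weights_only=False)
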
@@ -127,7 +127,8 @@ def load_data(traindir, valdir, args):
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print(f"Loading dataset_train from {cache_path}")
-        dataset, _ = torch.load(cache_path)
+        # TODO: this could probably be weights_only=True
+        dataset, _ = torch.load(cache_path, weights_only=False)
     else:
         # We need a default value for the variables below because args may come
         # from train_quantization.py which doesn't define them.
@@ -159,7 +160,8 @@ def load_data(traindir, valdir, args):
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print(f"Loading dataset_test from {cache_path}")
-        dataset_test, _ = torch.load(cache_path)
+        # TODO: this could probably be weights_only=True
+        dataset_test, _ = torch.load(cache_path, weights_only=False)
     else:
         if args.weights and args.test_only:
             weights = torchvision.models.get_weight(args.weights)
@@ -337,7 +339,7 @@ def main(args):
         model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

     if args.resume:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"])
         if not args.test_only:
             optimizer.load_state_dict(checkpoint["optimizer"])
...
@@ -74,7 +74,7 @@ def main(args):
         model_without_ddp = model.module

     if args.resume:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
...
@@ -287,8 +287,7 @@ def average_checkpoints(inputs):
     for fpath in inputs:
         with open(fpath, "rb") as f:
             state = torch.load(
-                f,
-                map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
+                f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
             )
         # Copies over the settings from the first checkpoint
         if new_state is None:
@@ -367,7 +366,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T

     # Deep copy to avoid side effects on the model object.
     model = copy.deepcopy(model)
-    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

     # Load the weights to the model to validate that everything works
     # and remove unnecessary weights (such as auxiliaries, etc.)
...
@@ -262,7 +262,7 @@ def load_checkpoint(args):
     utils.setup_ddp(args)

     if not args.weights:
-        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
+        checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
         if "model" in checkpoint:
             experiment_args = checkpoint["args"]
             model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)
...
@@ -498,7 +498,7 @@ def main(args):
     # load them from checkpoint if needed
     args.start_step = 0
     if args.resume_path is not None:
-        checkpoint = torch.load(args.resume_path, map_location="cpu")
+        checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
         if "model" in checkpoint:
             # this means the user requested to resume from a training checkpoint
             model_without_ddp.load_state_dict(checkpoint["model"])
...
@@ -288,7 +288,7 @@ def main(args):
     )

     if args.resume:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
...
@@ -226,7 +226,7 @@ def main(args):
         model_without_ddp = model

     if args.resume is not None:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"])

     if args.test_only:
...
@@ -223,7 +223,7 @@ def main(args):
         lr_scheduler = main_lr_scheduler

     if args.resume:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
         if not args.test_only:
             optimizer.load_state_dict(checkpoint["optimizer"])
...
@@ -101,7 +101,7 @@ def main(args):
     model = EmbeddingNet()
     if args.resume:
-        model.load_state_dict(torch.load(args.resume))
+        model.load_state_dict(torch.load(args.resume, weights_only=True))

     model.to(device)
...
@@ -164,7 +164,7 @@ def main(args):

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_train from {cache_path}")
-        dataset, _ = torch.load(cache_path)
+        dataset, _ = torch.load(cache_path, weights_only=True)
         dataset.transform = transform_train
     else:
         if args.distributed:
@@ -201,7 +201,7 @@ def main(args):

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
-        dataset_test, _ = torch.load(cache_path)
+        dataset_test, _ = torch.load(cache_path, weights_only=True)
         dataset_test.transform = transform_test
     else:
         if args.distributed:
@@ -295,7 +295,7 @@ def main(args):
         model_without_ddp = model.module

     if args.resume:
-        checkpoint = torch.load(args.resume, map_location="cpu")
+        checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
         model_without_ddp.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
...
@@ -1024,7 +1024,8 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
     #     "23_23_1.7": ...
     # }
     p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
-    true_cv2_results = torch.load(p)
+    true_cv2_results = torch.load(p, weights_only=False)

     if image_size == "small":
         tensor = (
...
@@ -149,7 +149,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
         if binary_size > MAX_PICKLE_SIZE:
             raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb")
     else:
-        expected = torch.load(expected_file)
+        expected = torch.load(expected_file, weights_only=True)
         rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
         atol = atol or prec
         torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
@@ -747,7 +747,7 @@ def test_segmentation_model(model_fn, dev):
             # so instead of validating the probability scores, check that the class
             # predictions match.
             expected_file = _get_expected_file(model_name)
-            expected = torch.load(expected_file)
+            expected = torch.load(expected_file, weights_only=True)
             torch.testing.assert_close(
                 out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
             )
@@ -847,7 +847,7 @@ def test_detection_model(model_fn, dev):
             # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
             # scores.
             expected_file = _get_expected_file(model_name)
-            expected = torch.load(expected_file)
+            expected = torch.load(expected_file, weights_only=True)
             torch.testing.assert_close(
                 output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
             )
...
@@ -215,7 +215,7 @@ class TestCommon:
         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
             buffer.seek(0)
-            assert_samples_equal(torch.load(buffer), sample)
+            assert_samples_equal(torch.load(buffer, weights_only=True), sample)

     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_infinite_buffer_size(self, dataset_mock, config):
...
@@ -3176,7 +3176,8 @@ class TestGaussianBlur:
     #     "26_28_1__23_23_1.7": cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7),
     # }
     REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS = torch.load(
-        Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt"
+        Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt",
+        weights_only=False,
     )

     @pytest.mark.parametrize(
...
@@ -375,7 +375,7 @@ def test_flow_to_image(batch):
     assert img.shape == (2, 3, h, w) if batch else (3, h, w)

     path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
-    expected_img = torch.load(path, map_location="cpu")
+    expected_img = torch.load(path, map_location="cpu", weights_only=True)
     if batch:
         expected_img = torch.stack([expected_img, expected_img])
...
@@ -84,7 +84,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
     file = os.path.join(root, file)

     if check_integrity(file):
-        return torch.load(file)
+        return torch.load(file, weights_only=True)
     else:
         msg = (
             "The meta file {} is not present in the root directory or is corrupted. "
...
@@ -116,7 +116,7 @@ class MNIST(VisionDataset):
         # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
         # directly.
         data_file = self.training_file if self.train else self.test_file
-        return torch.load(os.path.join(self.processed_folder, data_file))
+        return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)

     def _load_data(self):
         image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
...
@@ -106,7 +106,7 @@ class PhotoTour(VisionDataset):
             self.cache()

         # load the serialized data
-        self.data, self.labels, self.matches = torch.load(self.data_file)
+        self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)

     def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
         """
...