Unverified commit a638de19, authored by Yih-Dar and committed by GitHub

Fix `weights_only` (#28725)



fix
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d6ac8f4a
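
What the change does: `torch.load` only accepts a `weights_only` argument from PyTorch 1.13 onward. The previous code passed `weights_only=is_torch_greater_or_equal_than_1_13` unconditionally, so on releases older than 1.13 it still passed `weights_only=False`; `torch.load` forwards unknown keyword arguments to the pickle module, which rejects `weights_only` with a `TypeError`. Each call site below therefore builds the keyword arguments conditionally and omits the argument entirely when it is unsupported. A minimal, self-contained sketch of the pattern (the `safe_torch_load` helper name is illustrative and not part of the commit; in `transformers` the version flag `is_torch_greater_or_equal_than_1_13` is defined in `pytorch_utils`):

```python
import torch
from packaging import version

# Recomputed here so the sketch is self-contained; transformers defines this
# flag once in `pytorch_utils` and imports it at each call site.
is_torch_greater_or_equal_than_1_13 = version.parse(torch.__version__) >= version.parse("1.13")


def safe_torch_load(path):
    # torch >= 1.13: opt into the restricted, weights-only unpickler.
    # torch < 1.13: omit the kwarg entirely; passing it (even as False) would be
    # forwarded to the pickle module and raise a TypeError.
    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
    return torch.load(path, map_location="cpu", **weights_only_kwarg)
```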
```diff
@@ -330,10 +330,11 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network
 
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         state_dict = torch.load(
             pytorch_checkpoint_path,
             map_location="cpu",
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
         )
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
```
```diff
@@ -74,7 +74,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
         )
         raise
 
-    pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
 
     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -252,7 +253,8 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
 
         model_prefix = flax_model.base_model_prefix
```
```diff
@@ -188,7 +188,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+            state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
 
         pt_state_dict.update(state_dict)
```
```diff
@@ -482,11 +482,8 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)
 
-    loader = (
-        safe_load_file
-        if load_safe
-        else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
-    )
+    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
 
     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
```
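In `load_sharded_checkpoint` the same kwarg dict also lets the old four-line conditional collapse into a single expression: `functools.partial` bakes the shared `torch.load` arguments in, so either branch yields a loader that is called with just a path. A small sketch of that selection (the `load_safe` flag value is illustrative; torch >= 1.13 is assumed):

```python
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file

load_safe = False  # illustrative: True when the checkpoint shards are .safetensors files
weights_only_kwarg = {"weights_only": True}  # assuming torch >= 1.13

# Both branches produce a callable with the same shape: loader(path) -> state_dict.
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
```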
```diff
@@ -530,10 +527,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
+        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
-            weights_only=is_torch_greater_or_equal_than_1_13,
+            **weights_only_kwarg,
             **extra_args,
         )
     except Exception as e:
```
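In `load_state_dict`, the new kwarg dict composes with the existing `extra_args` used for memory-mapped loading. A hedged sketch of how the two dictionaries unpack into one call (assumes torch >= 2.1, where `torch.load` accepts `mmap=True` for zipfile-format checkpoints):

```python
import torch

# torch.save has produced zipfile-format checkpoints by default since 1.6,
# which is what mmap-based loading requires.
torch.save({"w": torch.zeros(4)}, "checkpoint.bin")

extra_args = {"mmap": True}                  # requires torch >= 2.1
weights_only_kwarg = {"weights_only": True}  # requires torch >= 1.13

# Mirrors the call in the diff above: both dicts unpack into the same call.
state_dict = torch.load("checkpoint.bin", map_location="cpu", **weights_only_kwarg, **extra_args)
```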
```diff
@@ -1334,10 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
                 cache_dir=cache_dir,
             )
 
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             state_dict = torch.load(
                 weight_path,
                 map_location="cpu",
-                weights_only=is_torch_greater_or_equal_than_1_13,
+                **weights_only_kwarg,
             )
         except EnvironmentError:
```
```diff
@@ -2102,6 +2102,7 @@ class Trainer:
             )
 
         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2120,7 +2121,7 @@ class Trainer:
                     state_dict = torch.load(
                         weights_file,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )
                     # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                     state_dict["_smp_is_partial"] = False
@@ -2137,7 +2138,7 @@ class Trainer:
                 state_dict = torch.load(
                     weights_file,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2190,6 +2191,7 @@ class Trainer:
             or os.path.exists(best_safe_adapter_model_path)
         ):
             has_been_loaded = True
+            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                     # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2209,7 +2211,7 @@ class Trainer:
                     state_dict = torch.load(
                         best_model_path,
                         map_location="cpu",
-                        weights_only=is_torch_greater_or_equal_than_1_13,
+                        **weights_only_kwarg,
                     )
                     state_dict["_smp_is_partial"] = False
@@ -2242,7 +2244,7 @@ class Trainer:
                 state_dict = torch.load(
                     best_model_path,
                     map_location="cpu",
-                    weights_only=is_torch_greater_or_equal_than_1_13,
+                    **weights_only_kwarg,
                 )
                 # If the model is on the GPU, it still works!
```
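A closing note on why the flag is worth threading through every call site: with `weights_only=True`, `torch.load` uses a restricted unpickler that only reconstructs tensors, primitive containers, and a small allow-list of types, so a tampered checkpoint cannot execute arbitrary code during unpickling. An illustrative demo, not part of the commit:

```python
import torch

# A plain tensor state dict loads fine under the restricted unpickler.
torch.save({"w": torch.ones(2, 2)}, "weights.pt")
state = torch.load("weights.pt", map_location="cpu", weights_only=True)


class Opaque:
    """Stand-in for an arbitrary (potentially malicious) pickled object."""


torch.save(Opaque(), "payload.pt")
try:
    torch.load("payload.pt", map_location="cpu", weights_only=True)
except Exception as err:  # pickle.UnpicklingError on torch >= 1.13
    print(f"refused to unpickle: {type(err).__name__}")
```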