Unverified Commit a1668cc7 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Use `weights_only` only if torch >= 1.13 (#28506)



* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 3005f965
......@@ -129,6 +129,7 @@ if is_torch_available():
XLMWithLMHeadModel,
XLNetLMHeadModel,
)
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
logging.set_verbosity_info()
......@@ -329,7 +330,11 @@ def convert_pt_checkpoint_to_tf(
if compare_with_pt_model:
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
state_dict = torch.load(
pytorch_checkpoint_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
pt_model = pt_model_class.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)
......
......@@ -50,6 +50,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
......@@ -68,7 +70,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
for k in f.keys():
pt_state_dict[k] = f.get_tensor(k)
else:
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
......@@ -245,11 +247,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
import torch
from .pytorch_utils import is_torch_greater_or_equal_than_1_13
# Load the index
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
pt_state_dict = torch.load(shard_file, weights_only=True)
pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
model_prefix = flax_model.base_model_prefix
......
......@@ -167,6 +167,8 @@ def load_pytorch_checkpoint_in_tf2_model(
import tensorflow as tf # noqa: F401
import torch # noqa: F401
from safetensors.torch import load_file as safe_load_file # noqa: F401
from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
......@@ -186,7 +188,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
pt_state_dict.update(state_dict)
......
......@@ -48,6 +48,7 @@ from .pytorch_utils import ( # noqa: F401
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
......@@ -481,7 +482,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
loader = (
safe_load_file
if load_safe
else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
......@@ -525,7 +530,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
return torch.load(
checkpoint_file,
map_location=map_location,
weights_only=is_torch_greater_or_equal_than_1_13,
**extra_args,
)
except Exception as e:
try:
with open(checkpoint_file) as f:
......
......@@ -37,6 +37,7 @@ from ...modeling_outputs import (
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
......@@ -1333,7 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir=cache_dir,
)
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
state_dict = torch.load(
weight_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
......
......@@ -64,7 +64,7 @@ from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
......@@ -2103,7 +2103,11 @@ class Trainer:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
)
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
......@@ -2116,7 +2120,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
else:
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
......@@ -2184,7 +2192,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
......@@ -2213,7 +2225,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment