Unverified Commit a1668cc7 authored by Yih-Dar, committed by GitHub

Use `weights_only` only if torch >= 1.13 (#28506)



* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 3005f965
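The change replaces the hard-coded `weights_only=True` in every `torch.load` call with a module-level flag, since `torch.load` only gained the `weights_only` argument (which restricts unpickling to tensor data rather than arbitrary Python objects) in PyTorch 1.13. Below is a minimal sketch of how such a flag can be computed; the actual definition lives in `pytorch_utils.py`, and the use of `packaging` here is an assumption, not a quote from the commit.

    import torch
    from packaging import version

    # Compare against the base version so pre-release/local suffixes
    # (e.g. "1.13.0a0+git1234abc") do not skew the comparison.
    _torch_version = version.parse(version.parse(torch.__version__).base_version)
    is_torch_greater_or_equal_than_1_13 = _torch_version >= version.parse("1.13")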
@@ -129,6 +129,7 @@ if is_torch_available():
         XLMWithLMHeadModel,
         XLNetLMHeadModel,
     )
+    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

 logging.set_verbosity_info()
@@ -329,7 +330,11 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

-        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
+        state_dict = torch.load(
+            pytorch_checkpoint_path,
+            map_location="cpu",
+            weights_only=is_torch_greater_or_equal_than_1_13,
+        )
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
         )
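This first file shows the substitution repeated throughout the commit: `weights_only=True` becomes `weights_only=is_torch_greater_or_equal_than_1_13`. As a hypothetical stand-alone illustration (names not from the commit), the same gate can also be expressed by building the kwargs dict, which drops the keyword entirely on older torch:

    import torch

    def load_checkpoint(path: str, supports_weights_only: bool):
        # Only pass weights_only when the running torch knows the argument.
        kwargs = {"weights_only": True} if supports_weights_only else {}
        return torch.load(path, map_location="cpu", **kwargs)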
@@ -50,6 +50,8 @@ def load_pytorch_checkpoint_in_flax_state_dict(
     """Load pytorch checkpoints in a flax model"""
     try:
         import torch  # noqa: F401
+
+        from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
     except (ImportError, ModuleNotFoundError):
         logger.error(
             "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
@@ -68,7 +70,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
             for k in f.keys():
                 pt_state_dict[k] = f.get_tensor(k)
     else:
-        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
+        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
         logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -245,11 +247,13 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
 def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     import torch
+
+    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

     # Load the index
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file, weights_only=True)
+        pt_state_dict = torch.load(shard_file, weights_only=is_torch_greater_or_equal_than_1_13)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

         model_prefix = flax_model.base_model_prefix
...
@@ -167,6 +167,8 @@ def load_pytorch_checkpoint_in_tf2_model(
         import tensorflow as tf  # noqa: F401
         import torch  # noqa: F401
         from safetensors.torch import load_file as safe_load_file  # noqa: F401
+
+        from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
     except ImportError:
         logger.error(
             "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -186,7 +188,7 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
+            state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)

         pt_state_dict.update(state_dict)
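In both conversion modules above, the new import sits inside the existing `try`/`except` so that an environment without PyTorch still reaches the friendly "please install" message instead of failing on the added line. A minimal sketch of that guard pattern, with a paraphrased error message:

    import logging

    logger = logging.getLogger(__name__)

    try:
        import torch  # noqa: F401

        # Must live under the same guard: pytorch_utils itself imports torch.
        from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        logger.error("This conversion requires PyTorch to be installed.")
        raise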
@@ -48,6 +48,7 @@ from .pytorch_utils import (  # noqa: F401
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     id_tensor_storage,
+    is_torch_greater_or_equal_than_1_13,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -481,7 +482,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
             error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
+    loader = (
+        safe_load_file
+        if load_safe
+        else partial(torch.load, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
+    )

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
@@ -525,7 +530,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
-        return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
+        return torch.load(
+            checkpoint_file,
+            map_location=map_location,
+            weights_only=is_torch_greater_or_equal_than_1_13,
+            **extra_args,
+        )
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
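`load_sharded_checkpoint` pre-binds the gated keyword with `functools.partial` so that safetensors and pickle shards share a single `loader(path)` calling convention, and `load_state_dict` combines the flag with the `mmap` extra argument. A condensed sketch of the `partial`-based dispatch, with an explicit `flag` parameter standing in for the module-level constant and an illustrative file name:

    import os
    from functools import partial

    import torch
    from safetensors.torch import load_file as safe_load_file

    def make_loader(load_safe: bool, flag: bool):
        # Bind the common torch.load kwargs once; both branches then
        # expose the same loader(path) -> state_dict interface.
        if load_safe:
            return safe_load_file
        return partial(torch.load, map_location="cpu", weights_only=flag)

    loader = make_loader(load_safe=False, flag=True)
    state_dict = loader(os.path.join("checkpoint", "pytorch_model.bin"))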
@@ -37,6 +37,7 @@ from ...modeling_outputs import (
     XVectorOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -1333,7 +1334,11 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
                 cache_dir=cache_dir,
             )

-            state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
+            state_dict = torch.load(
+                weight_path,
+                map_location="cpu",
+                weights_only=is_torch_greater_or_equal_than_1_13,
+            )

         except EnvironmentError:
             # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
...
...@@ -64,7 +64,7 @@ from .modelcard import TrainingSummary ...@@ -64,7 +64,7 @@ from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import ( from .trainer_callback import (
CallbackHandler, CallbackHandler,
...@@ -2103,7 +2103,11 @@ class Trainer: ...@@ -2103,7 +2103,11 @@ class Trainer:
logger.warning( logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
) )
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp). # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True) load_result = model.load_state_dict(state_dict, strict=True)
...@@ -2116,7 +2120,11 @@ class Trainer: ...@@ -2116,7 +2120,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(safe_weights_file): if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
else: else:
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) state_dict = torch.load(
weights_file,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs # which takes *args instead of **kwargs
...@@ -2184,7 +2192,11 @@ class Trainer: ...@@ -2184,7 +2192,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path): if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else: else:
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
state_dict["_smp_is_partial"] = False state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True) load_result = model.load_state_dict(state_dict, strict=True)
...@@ -2213,7 +2225,11 @@ class Trainer: ...@@ -2213,7 +2225,11 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path): if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else: else:
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True) state_dict = torch.load(
best_model_path,
map_location="cpu",
weights_only=is_torch_greater_or_equal_than_1_13,
)
# If the model is on the GPU, it still works! # If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
......
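All four Trainer call sites share the same fallback shape: prefer the safetensors checkpoint when `save_safetensors` is set and the file exists, otherwise fall back to the gated `torch.load`. A condensed, hypothetical sketch of that shared logic (function and parameter names are illustrative, not from the commit):

    import os

    import safetensors.torch
    import torch

    def load_trainer_weights(weights_file, safe_weights_file, save_safetensors, flag):
        # safetensors loading never unpickles arbitrary objects, so it is preferred.
        if save_safetensors and os.path.isfile(safe_weights_file):
            return safetensors.torch.load_file(safe_weights_file, device="cpu")
        # Fall back to torch.load, requesting weights_only only when supported.
        return torch.load(weights_file, map_location="cpu", weights_only=flag)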