Unverified commit 871598be authored by Viktor Scherbakov, committed by GitHub

Implemented safetensors checkpoints save/load for Trainer (#22498)



* implemented safetensors save/load

* remove duplicated file

* added tests

* more tests

* style fix

* fix tf tests

* change to list comprehension
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* review fixes + safe load for sharded checkpoint

* style fix

* remove rogue import

* remove partial to avoid undefined exception

* use naming alias instead of safetensors.torch

* fix safe sharding in tests

* grammar
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* minor corrections

* style

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 00b5887b
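
This change wires safetensors into the Trainer checkpointing path. Before the diff, here is a minimal usage sketch of the new flag; the model name (`prajjwal1/bert-tiny`), the dummy dataset, and the output directory are illustrative and not part of this PR:

```python
import torch
from torch.utils.data import Dataset

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


class DummyDataset(Dataset):
    # Tiny stand-in dataset so the sketch is self-contained; replace with real data.
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return {"input_ids": torch.tensor([101, 2023, 102]), "labels": torch.tensor(0)}


args = TrainingArguments(
    output_dir="safetensors-demo",  # illustrative path
    max_steps=4,
    save_steps=2,
    save_safetensors=True,  # new flag added by this PR, defaults to False
)

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
trainer = Trainer(model=model, args=args, train_dataset=DummyDataset())
trainer.train()

# Each checkpoint-* folder now holds model.safetensors instead of pytorch_model.bin.
# Resuming with save_safetensors=True prefers those safe weights:
# trainer.train(resume_from_checkpoint="safetensors-demo/checkpoint-2")
```
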
src/transformers/modeling_utils.py

@@ -336,7 +336,7 @@ def shard_checkpoint(
     return shards, index
 
 
-def load_sharded_checkpoint(model, folder, strict=True):
+def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
     [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
@@ -350,6 +350,9 @@ def load_sharded_checkpoint(model, folder, strict=True):
         folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
         strict (`bool`, *optional*, defaults to `True`):
             Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        prefer_safe (`bool`, *optional*, defaults to `False`)
+            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
+            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
 
     Returns:
         `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
@@ -358,10 +361,32 @@ def load_sharded_checkpoint(model, folder, strict=True):
     """
     # Load the index
     index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
-    if not os.path.isfile(index_file):
-        raise ValueError(f"Can't find a checkpoint index ({WEIGHTS_INDEX_NAME}) in {folder}.")
+    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
 
-    with open(index_file, "r", encoding="utf-8") as f:
+    index_present = os.path.isfile(index_file)
+    safe_index_present = os.path.isfile(safe_index_file)
+
+    if not index_present and not (safe_index_present and is_safetensors_available()):
+        filenames = (
+            (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
+        )
+        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
+
+    load_safe = False
+    if safe_index_present:
+        if prefer_safe:
+            if is_safetensors_available():
+                load_safe = True  # load safe due to preference
+            else:
+                logger.warning(
+                    f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
+                )
+        elif not index_present:
+            load_safe = True  # load safe since we have no other choice
+
+    load_index = safe_index_file if load_safe else index_file
+
+    with open(load_index, "r", encoding="utf-8") as f:
         index = json.load(f)
 
     shard_files = list(set(index["weight_map"].values()))
@@ -381,11 +406,13 @@ def load_sharded_checkpoint(model, folder, strict=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)
 
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu")
+
     for shard_file in shard_files:
-        state_dict = torch.load(os.path.join(folder, shard_file), map_location="cpu")
+        state_dict = loader(os.path.join(folder, shard_file))
         model.load_state_dict(state_dict, strict=False)
-        # Make sure memory is fred before we load the next state dict.
+        # Make sure memory is freed before we load the next state dict.
         del state_dict
         gc.collect()
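
The reworked loader can also be called directly. A minimal sketch, assuming a sharded checkpoint folder saved with `save_pretrained(..., safe_serialization=True)` and an architecture that matches it; the folder path and model name below are illustrative:

```python
from transformers import AutoModel
from transformers.modeling_utils import load_sharded_checkpoint

model = AutoModel.from_pretrained("bert-base-cased")  # must match the checkpoint's architecture

# prefer_safe picks the .safetensors shards when both .bin and .safetensors index files are
# present and safetensors is installed; otherwise the .bin shards are used.
load_result = load_sharded_checkpoint(model, "path/to/sharded-checkpoint", strict=False, prefer_safe=True)
print(load_result.missing_keys, load_result.unexpected_keys)
```
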
src/transformers/trainer.py

@@ -135,6 +135,8 @@ from .trainer_utils import (
 from .training_args import OptimizerNames, ParallelMode, TrainingArguments
 from .utils import (
     CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     can_return_loss,
@@ -145,6 +147,7 @@ from .utils import (
     is_datasets_available,
     is_in_notebook,
     is_ipex_available,
+    is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_compile_available,
@@ -198,6 +201,10 @@ else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
 
+if is_safetensors_available():
+    import safetensors.torch
+
+
 skip_first_batches = None
 if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
@@ -2091,15 +2098,22 @@ class Trainer:
         if model is None:
             model = self.model
 
-        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
-            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+
+        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+
+        if not any(
+            [os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]]
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
 
         logger.info(f"Loading model from {resume_from_checkpoint}.")
 
-        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
-            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+        if os.path.isfile(config_file):
+            config = PretrainedConfig.from_json_file(config_file)
             checkpoint_version = config.transformers_version
             if checkpoint_version is not None and checkpoint_version != __version__:
                 logger.warning(
@@ -2108,7 +2122,7 @@ class Trainer:
                     "yield to errors or unwanted behaviors."
                 )
 
-        if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2124,7 +2138,7 @@ class Trainer:
                         logger.warning(
                             "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                         )
-                    state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                    state_dict = torch.load(weights_file, map_location="cpu")
                     # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                     state_dict["_smp_is_partial"] = False
                     load_result = model.load_state_dict(state_dict, strict=True)
@@ -2132,7 +2146,11 @@ class Trainer:
                     del state_dict
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+                else:
+                    state_dict = torch.load(weights_file, map_location="cpu")
+
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                 # which takes *args instead of **kwargs
                 load_result = model.load_state_dict(state_dict, False)
@@ -2141,15 +2159,18 @@ class Trainer:
                 self._issue_warnings_after_load(load_result)
         else:
             # We load the sharded checkpoint
-            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled())
+            load_result = load_sharded_checkpoint(
+                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
+            )
             if not is_sagemaker_mp_enabled():
                 self._issue_warnings_after_load(load_result)
 
     def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if os.path.exists(best_model_path):
+        if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path):
             if self.deepspeed:
                 if self.model_wrapped is not None:
                     # this removes the pre-hooks from the previous engine
@@ -2181,12 +2202,20 @@ class Trainer:
                 else:
                     # If the 'user_content.pt' file does NOT exist, load with the old smp api.
                     # Checkpoint must have been saved with the old smp api.
-                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
                     state_dict["_smp_is_partial"] = False
                     load_result = model.load_state_dict(state_dict, strict=True)
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(best_model_path, map_location="cpu")
+                if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                    state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                else:
+                    state_dict = torch.load(best_model_path, map_location="cpu")
+
                 # If the model is on the GPU, it still works!
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                 # which takes *args instead of **kwargs
@@ -2837,17 +2866,24 @@ class Trainer:
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
         if not isinstance(self.model, PreTrainedModel):
-            if isinstance(unwrap_model(self.model), PreTrainedModel):
-                if state_dict is None:
-                    state_dict = self.model.state_dict()
-                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+
+            if isinstance(unwrap_model(self.model), PreTrainedModel):
+                unwrap_model(self.model).save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                if state_dict is None:
-                    state_dict = self.model.state_dict()
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            self.model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
 
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
@@ -3546,7 +3582,7 @@ class Trainer:
             output_dir = self.args.output_dir
         # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
-        modeling_files = [CONFIG_NAME, WEIGHTS_NAME]
+        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
         for modeling_file in modeling_files:
             if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                 shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
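
As a usage note for the `_load_best_model` changes above, the combination below is the one the new code path covers; a sketch with illustrative values only:

```python
from transformers import TrainingArguments

# With both flags set, the best checkpoint is restored from model.safetensors
# (best_safe_model_path above) rather than pytorch_model.bin at the end of training.
args = TrainingArguments(
    output_dir="safetensors-demo",  # illustrative path
    evaluation_strategy="steps",
    eval_steps=5,
    save_steps=5,
    load_best_model_at_end=True,
    save_safetensors=True,
)
```
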
src/transformers/training_args.py

@@ -42,6 +42,7 @@ from .utils import (
     get_full_repo_name,
     is_accelerate_available,
     is_psutil_available,
+    is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_available,
@@ -261,6 +262,9 @@ class TrainingArguments:
         save_total_limit (`int`, *optional*):
             If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
             `output_dir`.
+        save_safetensors (`bool`, *optional*, defaults to `False`):
+            Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
+            default `torch.load` and `torch.save`.
         save_on_each_node (`bool`, *optional*, defaults to `False`):
             When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
             the main one.
@@ -720,6 +724,12 @@ class TrainingArguments:
             )
         },
     )
+    save_safetensors: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
+        },
+    )
     save_on_each_node: bool = field(
         default=False,
         metadata={
@@ -1166,6 +1176,17 @@ class TrainingArguments:
                     f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
                 )
 
+        safetensors_available = is_safetensors_available()
+        if self.save_safetensors and not safetensors_available:
+            raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
+        if not self.save_safetensors and safetensors_available:
+            logger.info(
+                f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
+                f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
+                f"If your model cannot be saved by safetensors please feel free to open an issue at "
+                f"https://github.com/huggingface/safetensors!"
+            )
+
         if self.load_best_model_at_end and self.metric_for_best_model is None:
             self.metric_for_best_model = "loss"
         if self.greater_is_better is None and self.metric_for_best_model is not None:
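
The `__post_init__` check above turns a missing safetensors install into an early error instead of a failure at save time. A small sketch of the behavior, assuming an illustrative output directory:

```python
from transformers import TrainingArguments
from transformers.utils import is_safetensors_available

if is_safetensors_available():
    # Accepted: checkpoints will be written as safetensors files.
    args = TrainingArguments(output_dir="out", save_safetensors=True)
else:
    # Rejected at construction time with the ValueError added above.
    try:
        TrainingArguments(output_dir="out", save_safetensors=True)
    except ValueError as err:
        print(err)
```
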
tests/trainer/test_trainer.py

@@ -25,6 +25,7 @@ import sys
 import tempfile
 import time
 import unittest
+from itertools import product
 from pathlib import Path
 from unittest.mock import Mock, patch
 
@@ -54,6 +55,7 @@ from transformers.testing_utils import (
     require_intel_extension_for_pytorch,
     require_optuna,
     require_ray,
+    require_safetensors,
     require_sentencepiece,
     require_sigopt,
     require_tokenizers,
@@ -73,10 +75,13 @@ from transformers.testing_utils import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     is_apex_available,
     is_bitsandbytes_available,
+    is_safetensors_available,
     is_torchdistx_available,
 )
 from transformers.utils.hp_naming import TrialShortNamer
@@ -102,6 +107,9 @@ if is_torch_available():
     )
     from transformers.modeling_utils import unwrap_model
 
+    if is_safetensors_available():
+        import safetensors.torch
+
 
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
 
@@ -345,8 +353,9 @@ if is_torch_available():
 class TrainerIntegrationCommon:
-    def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
-        file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
+    def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False):
+        weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
+        file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
         if is_pretrained:
             file_list.append("config.json")
         for step in range(freq, total, freq):
@@ -356,7 +365,7 @@ class TrainerIntegrationCommon:
                 self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
 
     def check_best_model_has_been_loaded(
-        self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True
+        self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False
    ):
         checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}")
         log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history
@@ -370,7 +379,10 @@ class TrainerIntegrationCommon:
             best_model.to(trainer.args.device)
         else:
             best_model = RegressionModel()
-            state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            if not safe_weights:
+                state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+            else:
+                state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
             best_model.load_state_dict(state_dict)
             best_model.to(trainer.args.device)
         self.assertTrue(torch.allclose(best_model.a, trainer.model.a))
@@ -394,24 +406,43 @@ class TrainerIntegrationCommon:
                 _ = log1.pop(key, None)
             self.assertEqual(log, log1)
 
-    def convert_to_sharded_checkpoint(self, folder):
+    def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False):
         # Converts a checkpoint of a regression model to a sharded checkpoint.
-        state_dict = torch.load(os.path.join(folder, WEIGHTS_NAME))
-        os.remove(os.path.join(folder, WEIGHTS_NAME))
+        if load_safe:
+            loader = safetensors.torch.load_file
+            weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
+        else:
+            loader = torch.load
+            weights_file = os.path.join(folder, WEIGHTS_NAME)
+
+        if save_safe:
+            extension = "safetensors"
+            saver = safetensors.torch.save_file
+            index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
+            shard_name = SAFE_WEIGHTS_NAME
+        else:
+            extension = "bin"
+            saver = torch.save
+            index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
+            shard_name = WEIGHTS_NAME
+
+        state_dict = loader(weights_file)
+
+        os.remove(weights_file)
         keys = list(state_dict.keys())
 
         shard_files = [
-            WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(keys):05d}.bin") for idx in range(len(keys))
+            shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
+            for idx in range(len(keys))
         ]
         index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
 
-        save_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
-        with open(save_index_file, "w", encoding="utf-8") as f:
+        with open(index_file, "w", encoding="utf-8") as f:
             content = json.dumps(index, indent=2, sort_keys=True) + "\n"
             f.write(content)
 
         for param_name, shard_file in zip(keys, shard_files):
-            torch.save({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
+            saver({param_name: state_dict[param_name]}, os.path.join(folder, shard_file))
 
 
 @require_torch
@@ -1132,6 +1163,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
+    @require_safetensors
+    def test_safe_checkpoints(self):
+        for save_safetensors in [True, False]:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5, save_safetensors=save_safetensors)
+                trainer.train()
+                self.check_saved_checkpoints(
+                    tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), safe_weights=save_safetensors
+                )
+
+            # With a regular model that is not a PreTrainedModel
+            with tempfile.TemporaryDirectory() as tmpdir:
+                trainer = get_regression_trainer(
+                    output_dir=tmpdir, save_steps=5, pretrained=False, save_safetensors=save_safetensors
+                )
+                trainer.train()
+                self.check_saved_checkpoints(
+                    tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
+                )
+
     @require_torch_multi_gpu
     def test_run_seq2seq_double_train_wrap_once(self):
         # test that we don't wrap the model more than once
@@ -1373,6 +1424,42 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)
 
+    @require_safetensors
+    @require_torch_up_to_2_gpus
+    def test_resume_training_with_safe_checkpoint(self):
+        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
+        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
+        # won't be the same since the training dataloader is shuffled).
+
+        for initial_safe in [False, True]:
+            for loaded_safe in [False, True]:
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    trainer = get_regression_trainer(
+                        output_dir=tmpdir,
+                        train_len=128,
+                        save_steps=5,
+                        learning_rate=0.1,
+                        save_safetensors=initial_safe,
+                    )
+                    trainer.train()
+                    (a, b) = trainer.model.a.item(), trainer.model.b.item()
+                    state = dataclasses.asdict(trainer.state)
+
+                    checkpoint = os.path.join(tmpdir, "checkpoint-5")
+                    self.convert_to_sharded_checkpoint(checkpoint, load_safe=initial_safe, save_safe=loaded_safe)
+
+                    # Reinitialize trainer
+                    trainer = get_regression_trainer(
+                        output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1, save_safetensors=loaded_safe
+                    )
+
+                    trainer.train(resume_from_checkpoint=checkpoint)
+                    (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+                    state1 = dataclasses.asdict(trainer.state)
+                    self.assertEqual(a, a1)
+                    self.assertEqual(b, b1)
+                    self.check_trainer_state_are_the_same(state, state1)
+
     @require_torch_up_to_2_gpus
     def test_resume_training_with_gradient_accumulation(self):
         # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
@@ -1522,6 +1609,30 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=False)
             self.check_best_model_has_been_loaded(tmpdir, 5, total, trainer, "eval_loss", is_pretrained=False)
 
+    @require_safetensors
+    def test_load_best_model_from_safetensors(self):
+        total = int(self.n_epochs * 64 / self.batch_size)
+        for save_safetensors, pretrained in product([False, True], [False, True]):
+            with tempfile.TemporaryDirectory() as tmpdir:
+                trainer = get_regression_trainer(
+                    a=1.5,
+                    b=2.5,
+                    output_dir=tmpdir,
+                    learning_rate=0.1,
+                    eval_steps=5,
+                    evaluation_strategy="steps",
+                    save_steps=5,
+                    load_best_model_at_end=True,
+                    save_safetensors=save_safetensors,
+                    pretrained=pretrained,
+                )
+                self.assertFalse(trainer.args.greater_is_better)
+                trainer.train()
+                self.check_saved_checkpoints(tmpdir, 5, total, is_pretrained=pretrained, safe_weights=save_safetensors)
+                self.check_best_model_has_been_loaded(
+                    tmpdir, 5, total, trainer, "eval_loss", is_pretrained=pretrained, safe_weights=save_safetensors
+                )
+
     @slow
     def test_trainer_eval_mrpc(self):
         MODEL_ID = "bert-base-cased-finetuned-mrpc"