Unverified Commit c0328a6c authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Load checkpoint without re-creating the model (#11318)

parent 95037a16
......@@ -271,7 +271,7 @@ class PretrainedConfig(object):
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Drop the transformers version info
kwargs.pop("transformers_version", None)
self.transformers_version = kwargs.pop("transformers_version", None)
# Additional attributes without default values
for key, value in kwargs.items():
......
......@@ -55,9 +55,12 @@ from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .dependency_versions_check import dep_version_check
from .file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
is_apex_available,
is_datasets_available,
......@@ -999,14 +1002,23 @@ class Trainer:
logger.info(f"Loading model from {resume_from_checkpoint}).")
if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
checkpoint_version = config.transformers_version
if checkpoint_version is not None and checkpoint_version != __version__:
logger.warn(
f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
f"Transformers but your current version is {__version__}. This is not recommended and could "
"yield to errors or unwanted behaviors."
)
if self.deepspeed:
# will be resumed in deepspeed_init
pass
elif isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(resume_from_checkpoint)
model_reloaded = True
else:
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self.model.load_state_dict(state_dict)
# If model was re-initialized, put it on the right device and update self.model_wrapped
......@@ -1293,12 +1305,9 @@ class Trainer:
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)
if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if self.place_model_on_device:
self.model = self.model.to(args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME), map_location="cpu")
# If the model is on the GPU, it still works!
self.model.load_state_dict(state_dict)
if self.deepspeed:
......
......@@ -725,6 +725,46 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_resume_training_with_frozen_params(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.model.a.requires_grad_(False)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(tmpdir, "checkpoint-5")
# Reinitialize trainer
trainer = get_regression_trainer(
output_dir=tmpdir,
train_len=128,
per_device_train_batch_size=4,
save_steps=5,
learning_rate=0.1,
)
trainer.model.a.requires_grad_(False)
trainer.train(resume_from_checkpoint=checkpoint)
self.assertFalse(trainer.model.a.requires_grad)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)
with tempfile.TemporaryDirectory() as tmpdir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment