Unverified Commit cd8c93f7 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[DeepSpeed] improve checkpoint loading code plus tests (#10760)

* deepspeed checkpoint loading code plus tests

* style

* style
parent 01c7fb04
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import io
import json
import os
......@@ -19,6 +20,8 @@ import sys
import unittest
from copy import deepcopy
from transformers import TrainingArguments
from transformers.file_utils import WEIGHTS_NAME
from transformers.integrations import is_deepspeed_available
from transformers.testing_utils import (
CaptureStd,
......@@ -35,7 +38,7 @@ from transformers.trainer_utils import set_seed
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../../tests")
from test_trainer import get_regression_trainer # noqa
from test_trainer import TrainerIntegrationCommon, get_regression_trainer # noqa
set_seed(42)
......@@ -60,11 +63,21 @@ def require_deepspeed(test_case):
@require_deepspeed
@require_torch_gpu
class TrainerIntegrationDeepSpeed(TestCasePlus):
""" This class is for testing directly via get_regression_trainer """
class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"""
This class is for testing directly via get_regression_trainer
It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods which we can re-use here.
"""
def setUp(self):
super().setUp()
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
self.dist_env_1_gpu = dict(
MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
)
......@@ -222,6 +235,101 @@ class TrainerIntegrationDeepSpeed(TestCasePlus):
# see the note above how to get identical loss on a small bs
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, is_pretrained=True):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
ds_file_list = ["mp_rank_00_model_states.pt", "zero_pp_rank_0_mp_rank_00optim_states.pt"]
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint))
# common files
for filename in file_list:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
# ds files
ds_path = os.path.join(checkpoint, f"global_step{step}")
for filename in ds_file_list:
# filename = os.path.join(path, filename)
# print(filename)
self.assertTrue(os.path.isfile(os.path.join(ds_path, filename)))
def test_save_checkpoints(self):
# adapted from TrainerIntegrationTest.test_save_checkpoints
output_dir = self.get_auto_remove_tmp_dir()
ds_config_dict = deepcopy(self.ds_config_dict)
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
freq = 5
# save checkpoints
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(
output_dir=output_dir,
save_steps=freq,
deepspeed=ds_config_dict,
)
trainer.train()
total = int(self.n_epochs * 64 / self.batch_size)
self.check_saved_checkpoints_deepspeed(output_dir, freq, total)
def test_can_resume_training(self):
# adapted from TrainerIntegrationTest.test_can_resume_training
output_dir = self.get_auto_remove_tmp_dir()
ds_config_dict = deepcopy(self.ds_config_dict)
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint = os.path.join(output_dir, "checkpoint-5")
# Reinitialize trainer
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
# Now check with a later checkpoint that it also works when we span over one epoch
checkpoint = os.path.join(output_dir, "checkpoint-15")
# Reinitialize trainer and load model
trainer = get_regression_trainer(**kwargs)
trainer.train(resume_from_checkpoint=checkpoint)
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
# Now check failures
# 1. fail to find a bogus checkpoint
trainer = get_regression_trainer(**kwargs)
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
self.assertTrue("failed to resume from checkpoint" in str(context.exception))
# 2. fail to find any checkpoint - due a fresh output_dir
output_dir2 = self.get_auto_remove_tmp_dir()
trainer = get_regression_trainer(output_dir=output_dir2, deepspeed=ds_config_dict)
with self.assertRaises(Exception) as context:
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
@slow
@require_deepspeed
......
......@@ -21,6 +21,7 @@ import numbers
import os
import re
import tempfile
from copy import deepcopy
from pathlib import Path
from .utils import logging
......@@ -268,15 +269,19 @@ def rewrite_logs(d):
return new_d
def init_deepspeed(trainer, num_training_steps):
def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
"""
Init DeepSpeed, after converting any relevant Trainer's args into DeepSpeed configuration
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.
Args:
trainer: Trainer object
num_training_steps: per single gpu
resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
Returns: model, optimizer, lr_scheduler
"""
import deepspeed
......@@ -287,7 +292,9 @@ def init_deepspeed(trainer, num_training_steps):
model = trainer.model
if isinstance(args.deepspeed, dict):
config = args.deepspeed
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
# modified it, it will not be accepted here again, since some config params must be not set by users
config = deepcopy(args.deepspeed)
elif isinstance(args.deepspeed, str):
with io.open(ds_config_file, "r", encoding="utf-8") as f:
config = json.load(f)
......@@ -442,6 +449,15 @@ def init_deepspeed(trainer, num_training_steps):
lr_scheduler=lr_scheduler,
)
if resume_from_checkpoint is not None: # and os.path.isdir(resume_from_checkpoint):
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = model.load_checkpoint(
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
return model, optimizer, lr_scheduler
......
......@@ -878,7 +878,11 @@ class Trainer:
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
logger.info(f"Loading model from {resume_from_checkpoint}).")
if isinstance(self.model, PreTrainedModel):
if self.deepspeed:
# will be resumed in init_deepspeed
pass
elif isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(resume_from_checkpoint)
model_reloaded = True
else:
......@@ -920,7 +924,9 @@ class Trainer:
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
model, optimizer, lr_scheduler = init_deepspeed(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
self.model = model.module
self.model_wrapped = model
self.deepspeed = model # DeepSpeedEngine object
......@@ -1294,6 +1300,10 @@ class Trainer:
if checkpoint is None:
return
if self.deepspeed:
# deepspeed loads optimizer/lr_scheduler together with the model in init_deepspeed
return
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
os.path.join(checkpoint, "scheduler.pt")
):
......@@ -1318,10 +1328,6 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
if self.deepspeed:
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just return None if it fails to find a deepspeed checkpoint this is sort of a check-n-load function
self.deepspeed.load_checkpoint(checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
......
......@@ -24,6 +24,7 @@ import numpy as np
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
TestCasePlus,
get_tests_dir,
require_datasets,
require_optuna,
......@@ -235,28 +236,7 @@ if is_torch_available():
)
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(unittest.TestCase):
def setUp(self):
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.default_trained_model = (trainer.model.a, trainer.model.b)
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
def check_trained_model(self, model, alternate_seed=False):
# Checks a training seeded with learning_rate = 0.1
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True):
file_list = [WEIGHTS_NAME, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
if is_pretrained:
......@@ -306,6 +286,30 @@ class TrainerIntegrationTest(unittest.TestCase):
_ = log1.pop("train_samples_per_second", None)
self.assertEqual(log, log1)
@require_torch
@require_sentencepiece
@require_tokenizers
class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
def setUp(self):
super().setUp()
args = TrainingArguments(".")
self.n_epochs = args.num_train_epochs
self.batch_size = args.train_batch_size
trainer = get_regression_trainer(learning_rate=0.1)
trainer.train()
self.default_trained_model = (trainer.model.a, trainer.model.b)
trainer = get_regression_trainer(learning_rate=0.1, seed=314)
trainer.train()
self.alternate_trained_model = (trainer.model.a, trainer.model.b)
def check_trained_model(self, model, alternate_seed=False):
# Checks a training seeded with learning_rate = 0.1
(a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
self.assertTrue(torch.allclose(model.a, a))
self.assertTrue(torch.allclose(model.b, b))
def test_trainer_works_with_dict(self):
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
# anything.
......@@ -607,6 +611,7 @@ class TrainerIntegrationTest(unittest.TestCase):
# save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
# won't be the same since the training dataloader is shuffled).
return
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(output_dir=tmpdir, train_len=128, save_steps=5, learning_rate=0.1)
trainer.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment