Unverified Commit b66c5ab2 authored by Stas Bekman, committed by GitHub

[deepspeed] fix --load_best_model_at_end (#14652)

* [deepspeed] fix load_best_model_at_end

* try with pull_request_target

* revert: try with pull_request_target

* style

* add test

* cleanup
parent 30646a0a
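The gist of the fix: DeepSpeed currently can't resume from a checkpoint on an engine that has already done some stepping (https://github.com/microsoft/DeepSpeed/issues/1612), which broke `--load_best_model_at_end`. The workaround stashes the exact kwargs passed to `deepspeed.initialize()` inside `deepspeed_init()`, rebuilds a fresh engine from them once training is done, and loads the best checkpoint (including optimizer and lr scheduler states) into that new engine. A minimal sketch of the idea outside the `Trainer` (the helper name and standalone framing are illustrative, not part of the patch):

# Illustrative sketch only -- the function name and standalone framing are assumptions,
# not part of the patch; the real logic lives in deepspeed_reinit() / Trainer below.
import deepspeed


def rebuild_engine_and_load_best(stashed_kwargs, best_checkpoint_dir):
    # The live engine has already stepped, so rather than calling load_checkpoint()
    # on it (which trips https://github.com/microsoft/DeepSpeed/issues/1612),
    # build a brand-new engine from the kwargs stashed at deepspeed_init() time.
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**stashed_kwargs)
    load_path, _ = engine.load_checkpoint(
        best_checkpoint_dir, load_optimizer_states=True, load_lr_scheduler_states=True
    )
    if load_path is None:
        raise ValueError(f"failed to load the best model checkpoint from {best_checkpoint_dir}")
    return engine, optimizer, lr_scheduler

In the patch itself this is split between the new `deepspeed_reinit()` helper and the `--load_best_model_at_end` branch in `Trainer`, as the diff below shows.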
src/transformers/deepspeed.py

@@ -357,6 +357,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
     return optimizer, lr_scheduler
 
 
+def deepspeed_reinit(trainer):
+    """
+    this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
+    Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
+    https://github.com/microsoft/DeepSpeed/issues/1612
+    """
+    import deepspeed
+
+    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
+    return deepspeed_engine, optimizer, lr_scheduler
+
+
 def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
     """
     Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
@@ -398,12 +410,12 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         model_parameters = None
     else:
         optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
-        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
+        model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
 
     # keep for quick debug:
     # from pprint import pprint; pprint(config)
 
-    model, optimizer, _, lr_scheduler = deepspeed.initialize(
+    kwargs = dict(
         model=model,
         model_parameters=model_parameters,
         config_params=config,
@@ -411,6 +423,11 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         lr_scheduler=lr_scheduler,
     )
 
+    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
+
+    # stash kwargs to enable a later deepspeed_reinit
+    trainer.deepspeed_initialize_kwargs = kwargs
+
     if resume_from_checkpoint is not None:
 
         # it's possible that the user is trying to resume from model_path, which doesn't necessarily
@@ -424,7 +441,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         if len(deepspeed_checkpoint_dirs) > 0:
             logger.info(f"Attempting to resume from {resume_from_checkpoint}")
             # this magically updates self.optimizer and self.lr_scheduler
-            load_path, _ = model.load_checkpoint(
+            load_path, _ = deepspeed_engine.load_checkpoint(
                 resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
             )
             if load_path is None:
@@ -432,4 +449,4 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         else:
             logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
 
-    return model, optimizer, lr_scheduler
+    return deepspeed_engine, optimizer, lr_scheduler
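A note on the `list(filter(...))` change above: a bare `filter` object is a one-shot iterator, so the first `deepspeed.initialize(**kwargs)` call consumes it, and a later `deepspeed_reinit()` reusing the same stashed kwargs would hand DeepSpeed an already-exhausted parameter iterator. Materializing it with `list()` keeps the stashed kwargs reusable. Tiny plain-Python illustration (no DeepSpeed needed):

# filter() returns a lazy iterator that can only be walked once
params = [1, 2, 3]

lazy = filter(lambda p: p > 1, params)
print(list(lazy))   # [2, 3]
print(list(lazy))   # []   <- exhausted on the second pass

eager = list(filter(lambda p: p > 1, params))
print(eager)        # [2, 3]
print(eager)        # [2, 3] <- a list survives repeated use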
src/transformers/trainer.py

@@ -59,7 +59,7 @@ from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
 from .file_utils import (
     CONFIG_NAME,
@@ -1434,21 +1434,28 @@ class Trainer:
             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
             if os.path.exists(best_model_path):
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(best_model_path, map_location="cpu")
-                # If the model is on the GPU, it still works!
-                self._load_state_dict_in_model(state_dict)
+                if self.deepspeed:
+                    # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
+                    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+                    self.model = deepspeed_engine.module
+                    self.model_wrapped = deepspeed_engine
+                    self.deepspeed = deepspeed_engine
+                    self.optimizer = optimizer
+                    self.lr_scheduler = lr_scheduler
+                    self.deepspeed.load_checkpoint(
+                        self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
+                    )
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    # If the model is on the GPU, it still works!
+                    self._load_state_dict_in_model(state_dict)
             else:
                 logger.warn(
                     f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
                     "on multiple nodes, you should activate `--save_on_each_node`."
                 )
-            if self.deepspeed:
-                self.deepspeed.load_checkpoint(
-                    self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
-                )
 
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
         train_loss = self._total_loss_scalar / self.state.global_step
@@ -1975,6 +1982,9 @@ class Trainer:
             # This must be called on all ranks
             self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
 
+            # save a deepspeed checkpoint as well (this is very fast)
+            self.deepspeed.save_checkpoint(output_dir)
+
         elif self.args.should_save:
             self._save(output_dir)
tests/deepspeed/test_deepspeed.py

@@ -602,6 +602,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)
 
+        # Finally, should be able to resume with the same trainer/same deepspeed engine instance
+        # XXX: but currently this is not possible due to a DS bug: https://github.com/microsoft/DeepSpeed/issues/1612
+        # trainer.train(resume_from_checkpoint=checkpoint)
+        # a workaround needs to be used that re-creates the deepspeed engine
+
     @parameterized.expand(stages)
     def test_load_state_dict_from_zero_checkpoint(self, stage):
         # test that we can load fp32 weights directly from the zero checkpoint into the current model
@@ -968,3 +973,49 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
         assert "Detected DeepSpeed ZeRO-3" in cs.err
+
+    @parameterized.expand(stages)
+    def test_load_best_model(self, stage):
+        # this test exercises --load_best_model_at_end - the key is being able to resume after some training
+
+        data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"""
+            --model_name_or_path {T5_TINY}
+            --tokenizer_name {T5_TINY}
+            --train_file {data_dir}/train.json
+            --validation_file {data_dir}/val.json
+            --output_dir {output_dir}
+            --overwrite_output_dir
+            --source_lang en
+            --target_lang ro
+            --do_train
+            --max_train_samples 3
+            --do_eval
+            --max_eval_samples 1
+            --logging_strategy steps
+            --logging_steps 1
+            --evaluation_strategy steps
+            --eval_steps 1
+            --save_strategy steps
+            --save_steps 1
+            --load_best_model_at_end
+            --per_device_train_batch_size 1
+            --per_device_eval_batch_size 1
+            --num_train_epochs 1
+            --fp16
+            --report_to none
+        """.split()
+        args.extend(["--source_prefix", "translate English to Romanian: "])
+
+        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
+        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
+        launcher = get_launcher(distributed=False)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
+        with CaptureStderr() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+        # enough to test it didn't fail
+        assert "Detected DeepSpeed ZeRO-3" in cs.err