"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "7b1f6a850f9b7507d7cb6afc1d576e3bb5b265d0"
Unverified commit 6bc517cc authored by Sourab Mangrulkar, committed by GitHub

deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler (#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
parent 1110b565
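
For context, combination 4 in the support matrix touched below ("HF scheduler + DS optimizer") corresponds to a DeepSpeed config that defines an "optimizer" block but omits the "scheduler" block; before this change that combination raised a ValueError. A rough illustration of such a config, written as a Python dict in the style of the HF DeepSpeed docs (all values here are placeholders, not taken from this PR); the learning-rate schedule then comes from the Trainer's own lr_scheduler_type and warmup arguments:

# Illustrative DeepSpeed config (placeholder values): "optimizer" is present, so DeepSpeed
# builds the optimizer, while the missing "scheduler" block now falls back to the HF scheduler.
ds_config = {
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": "auto"},
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    # no "scheduler" block here on purpose
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}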
@@ -26,6 +26,8 @@ from ..utils import is_accelerate_available, is_torch_available, logging
 if is_torch_available():
     import torch
 
+    from ..optimization import get_scheduler
+
 
 logger = logging.get_logger(__name__)
@@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Mostly*
     # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
     #
     # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
@@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
         lr_scheduler = DummyScheduler(optimizer)
     else:
         if isinstance(optimizer, DummyOptim):
-            raise ValueError(
-                "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
-                "Please configure a scheduler in the DeepSpeed config."
-            )
-        lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+            def _lr_scheduler_callable(optimizer):
+                return get_scheduler(
+                    trainer.args.lr_scheduler_type,
+                    optimizer=optimizer,
+                    num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+
+            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
+        else:
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
 
     return optimizer, lr_scheduler
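
A note on the change above: when the DeepSpeed config defines the optimizer, the HF side only holds a DummyOptim and the real optimizer is created later by deepspeed.initialize, so the HF scheduler cannot be built eagerly. Wrapping get_scheduler in a callable and passing it to DummyScheduler via lr_scheduler_callable defers scheduler construction until the actual optimizer exists. A minimal sketch of what that callable builds, run here against a plain torch.optim.AdamW purely for illustration (the scheduler name, warmup, and step counts are placeholders that stand in for trainer.args):

import torch
from transformers import get_scheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)  # stand-in for the DeepSpeed-built optimizer
num_training_steps = 100

def _lr_scheduler_callable(optimizer):
    # Mirrors the callable in the diff; in the real code path the name and warmup
    # come from trainer.args.lr_scheduler_type and trainer.args.get_warmup_steps(...).
    return get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=10,
        num_training_steps=num_training_steps,
    )

# In the real flow the callable is handed to DummyScheduler and invoked once DeepSpeed has
# created its optimizer; here we call it directly to show the resulting HF scheduler.
lr_scheduler = _lr_scheduler_callable(optimizer)
for _ in range(num_training_steps):
    optimizer.step()
    lr_scheduler.step()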
@@ -60,7 +60,7 @@ from .data.data_collator import DataCollator, DataCollatorWithPadding, default_d
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
-from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -212,6 +212,9 @@ if is_accelerate_available():
         save_fsdp_optimizer,
     )
 
+    if is_deepspeed_available():
+        from accelerate.utils import DeepSpeedSchedulerWrapper
+
 
 if TYPE_CHECKING:
     import optuna
@@ -2362,7 +2365,14 @@ class Trainer:
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
 
         # Save SCHEDULER & SCALER
-        if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
+        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
+            self.lr_scheduler, DeepSpeedSchedulerWrapper
+        )
+        if (
+            self.args.should_save
+            and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
+            and not is_torch_tpu_available()
+        ):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
@@ -2428,6 +2438,10 @@ class Trainer:
         if self.is_deepspeed_enabled:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+            if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                reissue_pt_warnings(caught_warnings)
             return
 
         checkpoint_file_exists = (
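
Taken together, the Trainer changes mean that an HF scheduler used alongside a DeepSpeed optimizer (anything that is not a DeepSpeedSchedulerWrapper) now gets its state saved to scheduler.pt at checkpoint time and reloaded on resume, rather than being reinitialized. A hedged end-to-end sketch of that setup; the model name, toy dataset, and hyperparameters are placeholders, and a real run additionally needs a distributed launcher such as deepspeed or accelerate launch:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

class ToyDataset(torch.utils.data.Dataset):
    # Tiny placeholder dataset so the sketch is self-contained.
    def __init__(self, tokenizer, n=64):
        enc = tokenizer(["a short example sentence"] * n, padding="max_length", max_length=16, truncation=True)
        self.rows = [
            {"input_ids": enc["input_ids"][i], "attention_mask": enc["attention_mask"][i], "labels": i % 2}
            for i in range(n)
        ]

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        return self.rows[i]

ds_config = {
    # DeepSpeed optimizer defined, scheduler omitted: the combination enabled by this change.
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": "auto"},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

model_name = "bert-base-uncased"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

args = TrainingArguments(
    output_dir="ds_resume_sketch",
    deepspeed=ds_config,
    lr_scheduler_type="linear",  # HF scheduler, built through the new lr_scheduler_callable path
    warmup_steps=10,
    max_steps=100,
    save_steps=50,
    per_device_train_batch_size=8,
    fp16=True,
)

trainer = Trainer(model=model, args=args, train_dataset=ToyDataset(tokenizer))
trainer.train()

# Resuming (typically in a fresh run) now restores the HF scheduler from scheduler.pt
# alongside the DeepSpeed model/optimizer states instead of reinitializing it.
trainer.train(resume_from_checkpoint=True)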
@@ -136,6 +136,14 @@ ZERO3 = "zero3"
 FP16 = "fp16"
 BF16 = "bf16"
 
+HF_OPTIM = "hf_optim"
+HF_SCHEDULER = "hf_scheduler"
+DS_OPTIM = "ds_optim"
+DS_SCHEDULER = "ds_scheduler"
+
+optims = [HF_OPTIM, DS_OPTIM]
+schedulers = [HF_SCHEDULER, DS_SCHEDULER]
+
 stages = [ZERO2, ZERO3]
 if is_torch_bf16_gpu_available():
     dtypes = [FP16, BF16]
@@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
 # Cartesian-product of zero stages with models to test
 params = list(itertools.product(stages, dtypes))
 
+params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
+
 
 @require_deepspeed
 @require_torch_gpu
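
For reference, the new parameterization simply crosses the existing stage/dtype matrix with the optimizer and scheduler sources. A rough illustration of the resulting matrix, assuming a bf16-capable GPU so both dtypes are included:

import itertools

stages = ["zero2", "zero3"]
dtypes = ["fp16", "bf16"]
optims = ["hf_optim", "ds_optim"]
schedulers = ["hf_scheduler", "ds_scheduler"]

params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
print(len(params_with_optims_and_schedulers))  # 16 cases; the 4 hf_optim + hf_scheduler ones currently return early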
@@ -640,10 +650,16 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
             "Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
         )
 
-    @parameterized.expand(params, name_func=parameterized_custom_name_func)
-    def test_can_resume_training_normal(self, stage, dtype):
+    @parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
+    def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
         # adapted from TrainerIntegrationTest.test_can_resume_training
         # test normal resume for each stage separately, error-handling is tested in a different test
+
+        # ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
+        # also has same losses for few steps but then slowly diverges. Need to figure it out.
+        if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
+            return
+
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
         ds_config_dict = self.get_config_dict(stage)
         if dtype == FP16:
@@ -652,6 +668,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         if stage == ZERO3:
             ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
 
+        if optim == HF_OPTIM:
+            del ds_config_dict["optimizer"]
+
+        if scheduler == HF_SCHEDULER:
+            del ds_config_dict["scheduler"]
+
         kwargs = {
             "output_dir": output_dir,
             "train_len": 128,