Unverified Commit 42f359d0 authored by Olatunji Ruwase, committed by GitHub

Use DS callable API to allow hf_scheduler + ds_optimizer (#13216)



* Use DS callable API to allow hf_scheduler + ds_optimizer

* Preserve backward-compatibility

* Restore backward compatibility

* Tweak arg positioning

* Tweak arg positioning

* bump the required version

* Undo indent

* Update src/transformers/trainer.py

* style
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 35236b87
@@ -91,7 +91,7 @@ _deps = [
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.4.3",
+    "deepspeed>=0.5.1",
     "docutils==0.16.0",
     "fairscale>0.3",
     "faiss-cpu",
@@ -311,13 +311,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Yes
     # 3. DS scheduler + HF optimizer: Yes
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
     #
     # Unless Offload is enabled in which case it's:
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Mostly*
     # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
     #
     # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
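To make combination 4 concrete, here is a minimal sketch of a DeepSpeed config that exercises it: an "optimizer" block is present, so DeepSpeed builds the optimizer, while "scheduler" is omitted, so the scheduler chosen through the HF Trainer's --lr_scheduler_type is used. The dict and its values are illustrative only, not taken from this commit.

# Hedged sketch (not from this commit): ds_config for HF scheduler + DS optimizer.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "zero_optimization": {"stage": 2},
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    # no "scheduler" key: the HF Trainer scheduler (--lr_scheduler_type) is used instead
}

# Illustrative usage (model/dataset setup omitted):
# TrainingArguments(output_dir="out", deepspeed=ds_config, lr_scheduler_type="linear", ...)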
@@ -336,28 +336,20 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
         # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
         # But trainer uses AdamW by default.
-        trainer.create_optimizer()
-        optimizer = trainer.optimizer
+        optimizer = trainer.create_optimizer()
         # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
         config["zero_allow_untested_optimizer"] = True

-    # DS schedulers (deepspeed/runtime/lr_schedules.py):
-    #
-    # DS name      | --lr_scheduler_type  | HF func                           | Notes
-    # -------------| ---------------------|-----------------------------------|--------------------
-    # LRRangeTest  | na                   | na                                | LRRT
-    # OneCycle     | na                   | na                                | 1CLR
-    # WarmupLR     | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
-    # WarmupDecayLR| linear               | get_linear_schedule_with_warmup   |
+    def _lr_scheduler_callable(optimizer):
+        return trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
     lr_scheduler = None
     if "scheduler" not in config:
-        if "optimizer" in config:
-            # to make this option work, we need to init DS optimizer first, then init HS scheduler,
-            # then pass the HS scheduler to DS init, which is not possible at the moment
-            raise ValueError("At the moment HF scheduler + DeepSpeed optimizer combination is not possible")
+        if optimizer is None:
+            # Optimizer is not available, so use callable to defer lr_scheduler creation to DS init
+            lr_scheduler = _lr_scheduler_callable
         else:
-            trainer.create_scheduler(num_training_steps=num_training_steps)
-            lr_scheduler = trainer.lr_scheduler
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

     # keep for quick debug:
     # from pprint import pprint; pprint(config)
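The hunk above leans on DeepSpeed's callable form of the lr_scheduler argument (the reason for the deepspeed>=0.5.1 bump): instead of a scheduler instance, deepspeed.initialize can be given a callable, which DeepSpeed invokes with the optimizer it builds from the config. Below is a standalone sketch of that pattern under stated assumptions: the toy model, ds_config values and the name build_scheduler are placeholders, and the exact initialize keyword names may vary between DeepSpeed versions.

import torch
import deepspeed
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 4)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    # DeepSpeed builds this optimizer itself, so it does not exist on the client side yet
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-3}},
    # no "scheduler" section: the client supplies the lr scheduler
}

def build_scheduler(optimizer):
    # DeepSpeed calls this with the optimizer it instantiated, which is what lets a
    # client-side (HF/torch) scheduler be attached to a DeepSpeed-created optimizer.
    return LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config,
    lr_scheduler=build_scheduler,  # callable instead of a scheduler instance
)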
@@ -8,7 +8,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.4.3",
+    "deepspeed": "deepspeed>=0.5.1",
     "docutils": "docutils==0.16.0",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
@@ -768,7 +768,7 @@ class Trainer:
         and/or :obj:`create_scheduler`) in a subclass.
         """
         self.create_optimizer()
-        self.create_scheduler(num_training_steps)
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

     def create_optimizer(self):
         """
@@ -813,9 +813,12 @@
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(self.optimizer)

-    def create_scheduler(self, num_training_steps: int):
+        return self.optimizer
+
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
         """
-        Setup the scheduler. The optimizer of the trainer must have been set up before this method is called.
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.

         Args:
             num_training_steps (int): The number of training steps to do.
@@ -823,10 +826,11 @@
         if self.lr_scheduler is None:
             self.lr_scheduler = get_scheduler(
                 self.args.lr_scheduler_type,
-                self.optimizer,
+                optimizer=self.optimizer if optimizer is None else optimizer,
                 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                 num_training_steps=num_training_steps,
            )
+        return self.lr_scheduler

     def num_examples(self, dataloader: DataLoader) -> int:
         """
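Because create_optimizer and create_scheduler now return what they create, and create_scheduler accepts the optimizer explicitly instead of relying on self.optimizer, callers such as the DeepSpeed integration (and user subclasses) can build a scheduler for an optimizer they did not store on the Trainer. A hedged sketch of such an override follows; MyTrainer and the cosine schedule choice are illustrative and not part of this commit.

from transformers import Trainer, get_cosine_schedule_with_warmup

class MyTrainer(Trainer):
    def create_scheduler(self, num_training_steps: int, optimizer=None):
        if self.lr_scheduler is None:
            # honor an explicitly passed optimizer (e.g. one built by DeepSpeed),
            # falling back to the one stored on the trainer
            opt = self.optimizer if optimizer is None else optimizer
            self.lr_scheduler = get_cosine_schedule_with_warmup(
                opt,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler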
@@ -292,19 +292,16 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
        self.assertNotEqual(new_a, a)

    def test_hf_scheduler_ds_optimizer(self):
-        # this combo is not possible at the moment
+        a = 0
        with mockenv_context(**self.dist_env_1_gpu):
            ds_config_zero2_dict = self.get_config_dict(ZERO2)
            del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
            ds_config_zero2_dict["zero_optimization"]["offload_optimizer"]["device"] = "none"
            ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
-            with self.assertRaises(Exception) as context:
-                trainer.train()
-            self.assertTrue(
-                "HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception),
-                f"got exception: {context.exception}",
-            )
+            trainer.train()
+        new_a = trainer.model.a.item()
+        self.assertNotEqual(new_a, a)

    @require_deepspeed_aio
    def test_stage3_nvme_offload(self):