Unverified commit 7b702836, authored by Ziyang and committed by GitHub
Browse files

Support custom scheduler in deepspeed training (#26831)

Reuse `trainer.create_scheduler` to create the scheduler for DeepSpeed.
parent ca8944c4
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
""" """
Integration with Deepspeed Integration with Deepspeed
""" """
import copy
import importlib.metadata as importlib_metadata import importlib.metadata as importlib_metadata
import importlib.util import importlib.util
import weakref import weakref
...@@ -27,7 +27,6 @@ from ..utils import is_accelerate_available, is_torch_available, logging ...@@ -27,7 +27,6 @@ from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available(): if is_torch_available():
import torch import torch
from ..optimization import get_scheduler
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -341,12 +340,15 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
 if isinstance(optimizer, DummyOptim):
     def _lr_scheduler_callable(optimizer):
-        return get_scheduler(
-            trainer.args.lr_scheduler_type,
-            optimizer=optimizer,
-            num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
-            num_training_steps=num_training_steps,
-        )
+        # create a shallow copy first, so later modifications do not affect original trainer
+        trainer_copy = copy.copy(trainer)
+        # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
+        # update it to None so that we can re-create a new scheduler
+        trainer_copy.lr_scheduler = None
+        lr_scheduler = trainer_copy.create_scheduler(
+            num_training_steps=num_training_steps, optimizer=optimizer
+        )
+        return lr_scheduler
     lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
 else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment