Unverified Commit 3437d121 authored by Stas Bekman, committed by GitHub

[Trainer/Deepspeed] handle get_last_lr() before first step() (#10362)

* handle get_last_lr() before first step()

* abstract away the lr getting logic

* cleanup

* add test

* move to utils
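
For context, this is the guard pattern the change introduces, shown as a minimal standalone sketch rather than the actual Trainer code (`safe_last_lr` is a hypothetical name used only for illustration; the real helper is `get_learning_rate` in `trainer_pt_utils.py`, shown in the diff below):

```python
# Illustrative sketch only: under DeepSpeed fp16 with dynamic loss scaling, the
# optimizer/scheduler may skip the first steps while the loss scale is too large,
# so the scheduler can be queried for an lr before it has ever stepped.

def safe_last_lr(lr_scheduler):
    """Return the last learning rate, or 0 if the scheduler has not stepped yet."""
    try:
        return lr_scheduler.get_last_lr()[0]
    except AssertionError as e:
        # DeepSpeed schedulers raise an AssertionError mentioning "need to call step"
        # when queried before the first step; treat that as "no lr yet".
        if "need to call step" in str(e):
            return 0
        raise
```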
parent 4a1ab7cb
@@ -78,6 +78,31 @@ class TrainerIntegrationDeepSpeed(TestCasePlus):
                trainer.train()
        assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"

    def test_early_get_last_lr(self):
        # with deepspeed's fp16 and dynamic loss scale enabled, the optimizer/scheduler steps may
        # not run for the first few dozen steps while the loss scale is too large, and thus
        # `get_last_lr` will fail if called during that warm-up stage.
        #
        # setting `logging_steps=1` forces an early `trainer._maybe_log_save_evaluate()`, which calls
        # `self.lr_scheduler.get_last_lr()` and originally would fail on the very first step.
        with mockenv_context(**self.dist_env_1_gpu):
            a = b = 0.0
            trainer = get_regression_trainer(
                a=a,
                b=b,
                local_rank=0,
                train_len=8,
                deepspeed=self.ds_config_file,
                per_device_train_batch_size=8,
                logging_steps=1,
            )
            trainer.train()
            no_grad_accum_a = trainer.model.a.item()

            # it's enough that train didn't fail for this test, but we must check that the
            # optimizer/scheduler didn't run (since if it did, this test isn't testing the right thing)
            self.assertEqual(no_grad_accum_a, a)

    def test_gradient_accumulation(self):
        # this test measures that we get identical weights and similar loss with:
...
@@ -82,6 +82,7 @@ from .trainer_pt_utils import (
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    get_learning_rate,
    nested_concat,
    nested_detach,
    nested_numpify,
@@ -1129,12 +1130,8 @@ class Trainer:
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
-            # backward compatibility for pytorch schedulers
-            logs["learning_rate"] = (
-                self.lr_scheduler.get_last_lr()[0]
-                if version.parse(torch.__version__) >= version.parse("1.4")
-                else self.lr_scheduler.get_lr()[0]
-            )
+            logs["learning_rate"] = get_learning_rate(self)
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
......
@@ -24,6 +24,7 @@ from typing import Iterator, List, Optional, Union

import numpy as np
import torch
from packaging import version
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
@@ -262,6 +263,29 @@ def _get_first_shape(arrays):
    return arrays.shape

def get_learning_rate(trainer):
    if trainer.deepspeed:
        # with deepspeed's fp16 and dynamic loss scale enabled, the optimizer/scheduler steps may
        # not run for the first few dozen steps while the loss scale is too large, and thus
        # `get_last_lr` will fail if called during that warm-up stage, so work around it:
        try:
            last_lr = trainer.lr_scheduler.get_last_lr()[0]
        except AssertionError as e:
            if "need to call step" in str(e):
                logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
                last_lr = 0
            else:
                raise
    else:
        last_lr = (
            # backward compatibility for pytorch schedulers
            trainer.lr_scheduler.get_last_lr()[0]
            if version.parse(torch.__version__) >= version.parse("1.4")
            else trainer.lr_scheduler.get_lr()[0]
        )
    return last_lr

class DistributedTensorGatherer:
    """
    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
......