Unverified Commit a4763f55 authored by carefree0910, committed by GitHub

Supported customizing kwargs for lr_scheduler (#584)


Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 66268bd3
@@ -979,7 +979,7 @@ class DeepSpeedEngine(Module):
         torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                        max_norm=self.gradient_clipping())

-    def _take_model_step(self):
+    def _take_model_step(self, lr_kwargs):
         if self.gradient_clipping() > 0.0:
             if not self.fp16_enabled() and not self.amp_enabled():
                 self.clip_fp32_gradients()
@@ -1010,14 +1010,14 @@ class DeepSpeedEngine(Module):
                 self.skipped_steps += 1
         else:
             if self.lr_scheduler is not None:
-                self.lr_scheduler.step()
+                self.lr_scheduler.step(**(lr_kwargs or {}))

         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)

         self.global_steps += 1
         self.global_samples += self.train_batch_size()

-    def step(self):
+    def step(self, lr_kwargs=None):
         r"""Execute the weight update step after forward and backward propagation
         on effective_train_batch.
         """
@@ -1034,7 +1034,7 @@ class DeepSpeedEngine(Module):
             if self.progressive_layer_drop:
                 self.progressive_layer_drop.update_state(self.global_steps)

-            self._take_model_step()
+            self._take_model_step(lr_kwargs)

         self.tput_timer.stop(report_progress)
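The change threads an optional lr_kwargs dict from the public step() call down to lr_scheduler.step(), which lets schedulers whose step() takes arguments (such as torch.optim.lr_scheduler.ReduceLROnPlateau) be driven through the engine. A minimal usage sketch, assuming a standard DeepSpeed training setup: args, model, and data_loader are placeholders, and assigning engine.lr_scheduler directly is shown only for brevity.

import torch
import deepspeed

# Sketch only: `args`, `model`, and `data_loader` are assumed to exist;
# `args` is the usual argparse namespace carrying the DeepSpeed config.
engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                               model=model,
                                               model_parameters=model.parameters())

# A scheduler whose step() takes an argument: ReduceLROnPlateau expects
# the metric it monitors. Assigned directly here for brevity; a scheduler
# can also be handed to deepspeed.initialize.
engine.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

for batch in data_loader:
    loss = engine(batch)
    engine.backward(loss)
    # New in this commit: the dict is unpacked into lr_scheduler.step(),
    # so ReduceLROnPlateau receives its `metrics` argument.
    engine.step(lr_kwargs={'metrics': loss.item()})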
@@ -940,14 +940,14 @@ class PipelineEngine(DeepSpeedEngine):
         if self.wall_clock_breakdown():
             self.timers('pipe_recv_grad').stop()

-    def _exec_optimizer_step(self):
+    def _exec_optimizer_step(self, lr_kwargs=None):
         if self.wall_clock_breakdown():
             self.timers('step_microstep').start()
             self.timers('step').start()
         self.mem_status('BEFORE STEP', reset_max=True)

         self._force_grad_boundary = True
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
         self._force_grad_boundary = False

         self.mem_status('AFTER STEP')
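The **(lr_kwargs or {}) expression is what keeps the argument optional end to end: the None default unpacks to zero arguments, so existing schedulers with a parameter-less step() are unaffected. A standalone sketch of the idiom (all names here are illustrative, not DeepSpeed API):

class ConstantScheduler:
    """Toy scheduler with a parameter-less step(), like most PyTorch schedulers."""
    def step(self):
        print("stepped with no arguments")

class PlateauScheduler:
    """Toy scheduler whose step() consumes a keyword, like ReduceLROnPlateau."""
    def step(self, metrics):
        print(f"stepped with metrics={metrics}")

def take_model_step(scheduler, lr_kwargs=None):
    # Mirrors the commit: None unpacks as an empty dict, so one call
    # site serves both scheduler signatures.
    scheduler.step(**(lr_kwargs or {}))

take_model_step(ConstantScheduler())                             # no kwargs needed
take_model_step(PlateauScheduler(), lr_kwargs={'metrics': 0.42})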