Unverified Commit eb74c3f1 authored by Jeff Rasley, committed by GitHub

updates to amp to support grad clip and grad accumulation (#290)

* updates to amp to support grad clip and grad accumulation
* zero grad using optimizer if in amp mode
parent 3cc96e17
@@ -774,7 +774,12 @@ class DeepSpeedLight(Module):
         if self.zero_optimization():
             self.optimizer.backward(loss)
         elif self.amp_enabled():
-            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+            # AMP requires delaying unscale when inside gradient accumulation boundaries
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = not self.is_gradient_accumulation_boundary()
+            with amp.scale_loss(loss,
+                                self.optimizer,
+                                delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
@@ -828,14 +833,22 @@ class DeepSpeedLight(Module):
         if self.is_gradient_accumulation_boundary():
-            if not self.fp16_enabled() and self.gradient_clipping() > 0.0:
-                self.clip_fp32_gradients()
+            if self.gradient_clipping() > 0.0:
+                if not self.fp16_enabled() and not self.amp_enabled():
+                    self.clip_fp32_gradients()
+                elif self.amp_enabled():
+                    # AMP's recommended way of doing clipping
+                    # https://nvidia.github.io/apex/advanced.html#gradient-clipping
+                    master_params = amp.master_params(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(parameters=master_params,
+                                                   max_norm=self.gradient_clipping())
             self.optimizer.step()
-            #zero grad in basic optimizer could be unreliable and may not exhibit
-            #the behaviour that we want
-            if not self.zero_optimization() and not self.fp16_enabled():
+            # zero grad in basic optimizer could be unreliable and may not exhibit
+            # the behaviour that we want
+            if not self.zero_optimization() and not self.fp16_enabled(
+            ) and not self.amp_enabled():
                 self.zero_grad()
             else:
                 self.optimizer.zero_grad()
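Similarly, a minimal sketch (again not DeepSpeed code) of the apex-recommended clipping used in the second hunk: clip the unscaled fp32 master gradients obtained from amp.master_params rather than the fp16 model gradients. The max_norm of 1.0 is an arbitrary stand-in for self.gradient_clipping(), and the model and data are made up for illustration.

import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x, y = torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()
loss = torch.nn.functional.mse_loss(model(x), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# Clip the fp32 master copies that apex maintains for the optimizer.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()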