Unverified commit d7633a4e authored by Sylvain Gugger, committed by GitHub

Add basic support for FP16 in SageMaker model parallelism (#11407)

* Add FP16 support for SageMaker MP

* Add print debugs

* Squeeze

* Remove debug statements

* Add defensive check

* Typo
parent 38a716cd
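
For orientation, the patch targets jobs that turn on mixed precision through the usual training arguments while a SageMaker model-parallel backend is active. A minimal sketch of the arguments involved is below; it is illustrative only and not part of this commit, and the output directory is a placeholder.

    from transformers import TrainingArguments

    # Illustrative only: fp16=True routes the Trainer through the AMP branch patched
    # in this commit, where it now picks smp.amp.GradScaler() under SageMaker model
    # parallelism. max_grad_norm=0 disables gradient clipping, which the new defensive
    # check requires for SMP + fp16. "output" is a placeholder directory.
    args = TrainingArguments(output_dir="output", fp16=True, max_grad_norm=0)
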
@@ -412,7 +412,12 @@ class Trainer:
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
+                if is_sagemaker_mp_enabled():
+                    self.scaler = smp.amp.GradScaler()
+                elif self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                else:
+                    self.scaler = torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
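
The new branch only decides which scaler implementation backs the standard AMP training loop; torch.cuda.amp.GradScaler, fairscale's ShardedGradScaler, and smp.amp.GradScaler are treated as drop-in replacements in this code path. A minimal sketch of that loop with the stock PyTorch scaler follows (not code from this repository; it assumes a CUDA device is available).

    import torch

    # Standard AMP pattern that the selected scaler object plugs into (illustrative;
    # assumes a CUDA device).
    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(3):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(8, 10, device="cuda")).sum()
        scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)         # unscales gradients and skips the step on inf/nan
        scaler.update()                # adjusts the scale factor for the next iteration
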
@@ -420,6 +425,13 @@ class Trainer:
                     )
             self.use_apex = True

+        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+        if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+            raise ValueError(
+                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+                "along 'max_grad_norm': 0 in your hyperparameters."
+            )
+
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
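
TrainingArguments defaults max_grad_norm to 1.0, so any SageMaker model-parallel job with fp16 enabled hits this check unless the user overrides it. A rough sketch of passing the override through the SageMaker Python SDK's HuggingFace estimator is below; the entry point, role, instance type, and framework versions are placeholders, and the model-parallel distribution settings are omitted because they depend on the specific setup.

    from sagemaker.huggingface import HuggingFace

    # Placeholder estimator configuration; only the hyperparameters dict matters here.
    # 'max_grad_norm': 0 disables gradient clipping, as the new error message asks.
    estimator = HuggingFace(
        entry_point="train.py",                               # placeholder training script
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder IAM role
        instance_type="ml.p3.16xlarge",
        instance_count=1,
        transformers_version="4.6",
        pytorch_version="1.7",
        py_version="py36",
        hyperparameters={"fp16": True, "max_grad_norm": 0},
        # ...plus the smdistributed model-parallel 'distribution' configuration for your setup
    )
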
@@ -1607,7 +1619,8 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)

         if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            scaler = self.scaler if self.use_amp else None
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)

         if self.use_amp:
...
@@ -974,10 +974,15 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp

     @smp.step()
-    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
-        outputs = model(**inputs)
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
+        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
+            outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
         loss /= gradient_accumulation_steps
+        if scaler is not None:
+            loss = scaler.scale(loss).squeeze()
         model.backward(loss)
         return loss
...
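
With this change, the forward pass inside the smp.step-wrapped function runs under autocast and the loss is scaled before model.backward. The scaling exists to keep very small fp16 gradients from underflowing to zero; a standalone illustration (not from the repository) is below.

    import torch

    # Why loss scaling matters in fp16: the smallest representable magnitude is about
    # 6e-8, so tinier values round to zero. GradScaler's default initial scale is 65536.
    grad = torch.tensor(1e-9)
    print(grad.half())            # rounds to 0 in fp16
    print((grad * 65536).half())  # the scaled copy is still representable
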