Unverified Commit d7633a4e authored by Sylvain Gugger, committed by GitHub

Add basic support for FP16 in SageMaker model parallelism (#11407)

* Add FP16 support for SageMaker MP

* Add print debugs

* Squeeze

* Remove debug statements

* Add defensive check

* Typo
parent 38a716cd
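
For context, a minimal sketch of how a training script opts into the behavior this commit adds, assuming the standard TrainingArguments fields; the values are illustrative and not taken from this commit. Under SageMaker model parallelism, fp16=True now selects smp.amp.GradScaler, and max_grad_norm has to be 0 because gradient clipping is not supported in that mode yet:

    # Hedged sketch: illustrative values, not part of this commit.
    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="output",
        fp16=True,        # with SageMaker MP enabled, this now builds an smp.amp.GradScaler
        max_grad_norm=0,  # gradient clipping is unsupported for now under SageMaker MP + fp16
    )

The same two keys can equally be passed as "fp16" and "max_grad_norm" in the hyperparameters of a SageMaker training job, which is what the new error message below refers to.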
@@ -412,7 +412,12 @@ class Trainer:
         if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
             if self.fp16_backend == "amp":
                 self.use_amp = True
-                self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
+                if is_sagemaker_mp_enabled():
+                    self.scaler = smp.amp.GradScaler()
+                elif self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                else:
+                    self.scaler = torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
@@ -420,6 +425,13 @@ class Trainer:
                     )
                 self.use_apex = True

+        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+        if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0:
+            raise ValueError(
+                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+                "along 'max_grad_norm': 0 in your hyperparameters."
+            )
+
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
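
A note on the new ValueError: with AMP, gradients are scaled, so they have to be unscaled before a norm-based clip makes sense, and that unscale-then-clip path is presumably not wired up yet for the smp scaler and partitioned parameters (an assumption here; the commit only says clipping does not work for now). For reference, the standard torch.cuda.amp recipe that clipping requires looks like the sketch below, which is generic PyTorch rather than code from this commit:

    # Generic torch.cuda.amp clip-with-scaler recipe, for illustration only.
    import torch

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    max_grad_norm = 1.0

    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)              # bring gradients back to their true scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)                  # the step is skipped automatically on overflow
    scaler.update()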
@@ -1607,7 +1619,8 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)

         if is_sagemaker_mp_enabled():
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+            scaler = self.scaler if self.use_amp else None
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)

         if self.use_amp:
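
The scaler is only forwarded when AMP is active, so the non-fp16 SageMaker path is unchanged; and because smp_forward_backward is decorated with smp.step, it returns per-microbatch losses, which is why the result is reduced with reduce_mean() before being detached. For comparison, a generic mixed-precision training step outside SageMaker MP follows the pattern below (standard PyTorch AMP, not exact Trainer code):

    # Generic AMP training step for comparison; standard PyTorch, not exact Trainer code.
    import torch

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    def training_step(batch):
        with torch.cuda.amp.autocast():
            loss = model(batch).sum()     # forward in mixed precision
        scaler.scale(loss).backward()     # scaled backward, mirroring what the smp path now does
        return loss.detach()

    training_step(torch.randn(8, 4, device="cuda"))
    scaler.step(optimizer)                # unscales the gradients and applies the update
    scaler.update()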
@@ -974,10 +974,15 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp

     @smp.step()
-    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
-        outputs = model(**inputs)
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
+        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
+            outputs = model(**inputs)
+
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
         loss /= gradient_accumulation_steps
+        if scaler is not None:
+            loss = scaler.scale(loss).squeeze()
         model.backward(loss)
         return loss
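
In the updated helper, autocast is keyed on whether a scaler was passed, so the fp32 path runs through exactly the same code with autocast disabled, and the loss is scaled before model.backward (the squeeze presumably drops an extra dimension introduced by scaling inside smp.step). A small sketch of that enabled= toggle in plain PyTorch; forward_with_optional_amp is a hypothetical helper used only for illustration:

    # Hypothetical helper showing the enabled=(scaler is not None) toggle; illustration only.
    import torch

    def forward_with_optional_amp(model, x, scaler=None):
        # autocast(enabled=False) is a no-op, so fp32 and fp16 share a single code path
        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            out = model(x)
        loss = out.sum()
        if scaler is not None:
            loss = scaler.scale(loss)     # scale before backward, as smp_forward_backward does
        return loss

    m = torch.nn.Linear(4, 2).cuda()
    forward_with_optional_amp(m, torch.randn(8, 4, device="cuda")).backward()                               # fp32 path
    forward_with_optional_amp(m, torch.randn(8, 4, device="cuda"), torch.cuda.amp.GradScaler()).backward()  # fp16 path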