Unverified Commit 1690094b authored by haohanchen-yagao, committed by GitHub

Add FP16 Support for SageMaker Model Parallel (#17386)

* Add FP16 support for SageMaker model parallel

* minor fix

* fix indentation

* handle mixed precision exception for SMMP

* minor fix

* remove amp implementation on SMMP

* remove redundant stuff

* reformat trainer

* restyling

* reformat
parent 4aabf9b5
@@ -494,6 +494,20 @@ class Trainer:
         self.use_cuda_amp = False
         self.use_cpu_amp = False
 
+        # Mixed precision setup for SageMaker Model Parallel
+        if is_sagemaker_mp_enabled():
+            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+            if args.bf16:
+                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead.")
+            # When the SMP config and the trainer argument disagree, treat the SMP config as the source of truth
+            if args.fp16 != smp.state.cfg.fp16:
+                logger.warning(
+                    f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                    f"but FP16 provided in trainer argument is {args.fp16}, "
+                    f"setting to {smp.state.cfg.fp16}"
+                )
+                args.fp16 = smp.state.cfg.fp16
+
         if args.fp16 or args.bf16:
             if self.fsdp is not None:
                 raise ValueError(
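Note on this hunk: `smp.state.cfg.fp16` reflects the model-parallel parameters that SageMaker passes to the training job and surfaces to the process as `SM_HP_MP_PARAMETERS`. Below is a minimal, hypothetical sketch of such a parameter set; only the `fp16` key is the flag this hunk reads, the other keys and all values are illustrative assumptions and depend on the smdistributed.modelparallel version.

    # Hypothetical model-parallel parameter set; SageMaker surfaces it to the training
    # process via SM_HP_MP_PARAMETERS, and smp.state.cfg is built from it.
    smp_parameters = {
        "partitions": 2,    # illustrative value
        "microbatches": 4,  # illustrative value
        "ddp": True,        # illustrative value
        "fp16": True,       # the flag read above; it takes precedence over TrainingArguments.fp16
    }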
@@ -519,14 +533,13 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")
 
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not args.deepspeed:  # deepspeed manages its own half precision
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+            # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
                 self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
                 self.do_grad_scaling = True
-                if is_sagemaker_mp_enabled():
-                    self.scaler = smp.amp.GradScaler()
-                elif self.sharded_ddp is not None:
+                if self.sharded_ddp is not None:
                     self.scaler = ShardedGradScaler()
                 elif is_torch_tpu_available():
                     from torch_xla.amp import GradScaler
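Note on this hunk: with the extra condition, an SMP job never enters the Trainer's own AMP setup, so no `torch.cuda.amp` grad scaler is created and half precision is left entirely to SMP, mirroring the existing DeepSpeed behavior. A minimal sketch of the selection rule, written as a hypothetical standalone helper:

    def trainer_manages_amp(fp16: bool, bf16: bool, deepspeed: bool, smp_enabled: bool) -> bool:
        # The Trainer only configures cuda_amp/cpu_amp and a GradScaler when neither
        # DeepSpeed nor SageMaker Model Parallel is managing half precision itself.
        return (fp16 or bf16) and not (deepspeed or smp_enabled)

    assert trainer_manages_amp(True, False, False, False) is True   # plain fp16: Trainer sets up AMP
    assert trainer_manages_amp(True, False, False, True) is False   # SMP fp16: delegated to SMP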
@@ -545,18 +558,6 @@ class Trainer:
                 )
             self.use_apex = True
 
-        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
-        if (
-            is_sagemaker_mp_enabled()
-            and self.use_cuda_amp
-            and args.max_grad_norm is not None
-            and args.max_grad_norm > 0
-        ):
-            raise ValueError(
-                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
-                "along 'max_grad_norm': 0 in your hyperparameters."
-            )
-
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
@@ -938,7 +939,10 @@ class Trainer:
             `create_scheduler`) in a subclass.
         """
         self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+        self.create_scheduler(
+            num_training_steps=num_training_steps,
+            optimizer=self.optimizer.optimizer if is_sagemaker_mp_enabled() and smp.state.cfg.fp16 else self.optimizer,
+        )
 
     def create_optimizer(self):
         """
@@ -1641,7 +1645,9 @@ class Trainer:
                         # AMP: gradients need unscaling
                         self.scaler.unscale_(self.optimizer)
 
-                    if hasattr(self.optimizer, "clip_grad_norm"):
+                    if is_sagemaker_mp_enabled() and args.fp16:
+                        self.optimizer.clip_master_grads(args.max_grad_norm)
+                    elif hasattr(self.optimizer, "clip_grad_norm"):
                         # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
                         self.optimizer.clip_grad_norm(args.max_grad_norm)
                     elif hasattr(model, "clip_grad_norm_"):
@@ -2068,7 +2074,9 @@ class Trainer:
                 if is_sagemaker_mp_enabled():
 
                     def opt_load_hook(mod, opt):
-                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+                        opt.load_state_dict(
+                            smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True), gather_if_shard=False
+                        )
 
                     self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
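Note on this hunk: `gather_if_shard=False` keeps each rank's shard of the optimizer state local instead of gathering a full state dict on load. A hedged sketch of the load side this hook implements; the `smp.load(..., partial=True)` and `load_state_dict(..., gather_if_shard=False)` calls mirror the diff, while the constant value and helper name are illustrative assumptions:

    import os

    import smdistributed.modelparallel.torch as smp

    OPTIMIZER_NAME = "optimizer.pt"  # illustrative; mirrors the constant used above

    def load_partial_optimizer_state(opt, checkpoint):
        # Load this rank's partial optimizer checkpoint and keep the shards local
        # instead of gathering a full state dict across ranks.
        state = smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)
        opt.load_state_dict(state, gather_if_shard=False)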
@@ -2292,8 +2300,7 @@ class Trainer:
         inputs = self._prepare_inputs(inputs)
 
         if is_sagemaker_mp_enabled():
-            scaler = self.scaler if self.do_grad_scaling else None
-            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
+            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         with self.compute_loss_context_manager():
@@ -1020,15 +1020,10 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
 
     @smp.step()
-    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1, scaler=None):
-        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
-            outputs = model(**inputs)
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+        outputs = model(**inputs)
         loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
         loss /= gradient_accumulation_steps
-        if scaler is not None:
-            loss = scaler.scale(loss).squeeze()
         model.backward(loss)
         return loss
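Note on this hunk: with the scaler argument gone, `smp_forward_backward` is a plain `@smp.step`-decorated forward/backward, and FP16 loss scaling now happens inside SMP's own optimizer rather than in this helper. A usage sketch that mirrors the training_step hunk above; the wrapper function name and parameters are assumptions for illustration:

    def training_step_smp(model, inputs, gradient_accumulation_steps, device):
        # Run the @smp.step forward/backward and average the per-microbatch losses
        # returned by SMP, as the Trainer does above.
        loss_mb = smp_forward_backward(model, inputs, gradient_accumulation_steps)
        return loss_mb.reduce_mean().detach().to(device)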