Unverified Commit 9f0646a5 authored by Sourab Mangrulkar, committed by GitHub

Smangrul/accelerate mp integrate (#23148)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* address comments by removing debugging print statements
parent de9255de
@@ -212,6 +212,8 @@ if is_accelerate_available():
     if version.parse(accelerate_version) >= version.parse("0.16"):
         from accelerate import skip_first_batches
 
+    from accelerate import Accelerator
+
 
 if TYPE_CHECKING:
     import optuna
@@ -337,6 +339,9 @@ class Trainer:
         self.deepspeed = None
         self.is_in_train = False
 
+        # create accelerator object
+        self.accelerator = Accelerator()
+
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
         self._memory_tracker.start()
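The `Accelerator` is created without arguments, so its mixed-precision mode comes from the environment (the `ACCELERATE_MIXED_PRECISION` variable that `TrainingArguments` exports further down in this diff) or from an `accelerate config` file. A small standalone sketch of that pickup, not Trainer code; note the variable has to be set before the object is constructed, and `mixed_precision`/`device` are public `Accelerator` attributes:

```python
# Standalone sketch: an argument-less Accelerator() reads its mixed-precision
# mode from the environment, which is how the Trainer's fp16/bf16 flags reach it.
import os

os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"  # what TrainingArguments sets for --bf16

from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.mixed_precision)  # "bf16"
print(accelerator.device)           # e.g. cuda:0 on a single-GPU machine, else cpu
```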
@@ -607,7 +612,7 @@ class Trainer:
                     "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                 )
 
-        if args.fp16 or args.bf16:
+        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
             if args.half_precision_backend == "auto":
                 if args.device == torch.device("cpu"):
                     if args.fp16:
@@ -624,6 +629,7 @@ class Trainer:
         self.do_grad_scaling = False
         if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
+            if self.sharded_ddp is not None:
                 if args.half_precision_backend == "cuda_amp":
                     self.use_cuda_amp = True
                     self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
@@ -647,7 +653,7 @@ class Trainer:
                 elif args.half_precision_backend == "cpu_amp":
                     self.use_cpu_amp = True
                     self.amp_dtype = torch.bfloat16
-            else:
+            elif args.half_precision_backend == "apex":
                 if not is_apex_available():
                     raise ImportError(
                         "Using FP16 with APEX but APEX is not installed, please refer to"
@@ -1801,6 +1807,11 @@ class Trainer:
         if delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)
 
+        # prepare using `accelerator` prepare
+        model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+            self.model, self.optimizer, self.lr_scheduler
+        )
+
         # Check if saved optimizer or scheduler states exist
         self._load_optimizer_and_scheduler(resume_from_checkpoint)
 
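`prepare` returns wrapped versions of whatever is passed in, in the same order: the model may come back wrapped (for example in `DistributedDataParallel`), and the optimizer and scheduler are wrapped so stepping stays correct under gradient scaling. A hedged sketch of the call pattern with toy objects, not Trainer internals:

```python
# Sketch of the prepare() call pattern used above (toy model/optimizer/scheduler).
import torch
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Returned in the same order they were passed in.
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

# The original module is still reachable, e.g. for saving checkpoints.
unwrapped = accelerator.unwrap_model(model)
```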
@@ -2013,10 +2024,15 @@ class Trainer:
                     elif hasattr(model, "clip_grad_norm_"):
                         # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
                         model.clip_grad_norm_(args.max_grad_norm)
-                    else:
+                    elif self.use_apex:
                         # Revert to normal clipping otherwise, handling Apex or full precision
                         nn.utils.clip_grad_norm_(
-                            amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+                            amp.master_params(self.optimizer),
+                            args.max_grad_norm,
+                        )
+                    else:
+                        self.accelerator.clip_grad_norm_(
+                            model.parameters(),
                             args.max_grad_norm,
                         )
 
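Clipping moves from `nn.utils.clip_grad_norm_` to `accelerator.clip_grad_norm_` in the non-apex branch because, under fp16, gradients are held in scaled form until the optimizer step; Accelerate's method unscales them first so `max_grad_norm` is applied to the true gradient values. A rough sketch of one step outside the Trainer (toy objects; `max_grad_norm` stands in for `args.max_grad_norm`):

```python
# Sketch: one optimization step where clipping goes through the accelerator,
# which unscales fp16 gradients before applying the norm limit.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

loss = model(torch.randn(8, 4)).pow(2).mean()
accelerator.backward(loss)

max_grad_norm = 1.0  # placeholder for args.max_grad_norm
accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()
optimizer.zero_grad()
```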
@@ -2802,7 +2818,7 @@ class Trainer:
             # loss gets scaled under gradient_accumulation_steps in deepspeed
             loss = self.deepspeed.backward(loss)
         else:
-            loss.backward()
+            self.accelerator.backward(loss)
 
         return loss.detach()
 
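`accelerator.backward(loss)` takes over what the removed `loss.backward()` shared with the Trainer's own GradScaler path: when fp16 is active, the loss is scaled before backpropagation so small gradients do not underflow. A simplified sketch of that behaviour, which is an assumption about Accelerate's internals rather than its actual code, ignoring deepspeed and other integrations:

```python
# Simplified illustration of what accelerator.backward(loss) roughly does;
# the real method also handles deepspeed, megatron-lm, and extra kwargs.
def backward_like_accelerate(accelerator, loss):
    if accelerator.scaler is not None:
        # fp16: scale the loss so small gradients do not underflow
        accelerator.scaler.scale(loss).backward()
    else:
        # bf16 / fp32: a plain backward is enough
        loss.backward()
```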
@@ -1562,6 +1562,15 @@ class TrainingArguments:
                 FutureWarning,
             )
 
+        # if training args is specified, it will override the one specified in the accelerate config
+        if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
+            mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+            if self.fp16:
+                mixed_precision_dtype = "fp16"
+            elif self.bf16:
+                mixed_precision_dtype = "bf16"
+            os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
+
     def __str__(self):
         self_as_dict = asdict(self)
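This `__post_init__` block is how the user-facing flags reach the `Accelerator` created in `Trainer.__init__`: `--fp16`/`--bf16` win over whatever an `accelerate config` run put into `ACCELERATE_MIXED_PRECISION`, and nothing is touched when apex or sharded_ddp is in use. The same precedence rule, restated as a standalone helper (the function name is illustrative, not part of the PR):

```python
# Illustrative helper mirroring the precedence above: explicit fp16/bf16 flags
# override an existing accelerate setting, otherwise the environment value wins.
import os


def resolve_mixed_precision(fp16: bool, bf16: bool) -> str:
    mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
    if fp16:
        mixed_precision_dtype = "fp16"
    elif bf16:
        mixed_precision_dtype = "bf16"
    return mixed_precision_dtype


os.environ["ACCELERATE_MIXED_PRECISION"] = resolve_mixed_precision(fp16=True, bf16=False)
assert os.environ["ACCELERATE_MIXED_PRECISION"] == "fp16"
```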