Unverified Commit fb13a7df authored by Kashif Rasul, committed by GitHub

do not scale gradient in bf16 mode (#21428)

* do not scale gradient in bf16 mode

* fix since args.fp16 might be None

* fixed typo

* typo

* only do if grad scaling is true

* self.amp_dtype == torch.float16 is true

* put back prop when fsdp is not None
parent 197e7ce9
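
For context on the change: with the `cuda_amp` backend, gradient scaling only matters when autocasting to float16, because fp16's narrow exponent range lets small gradients underflow to zero; bfloat16 keeps float32's exponent range, so no `GradScaler` is needed. The following is a minimal, hypothetical sketch of the underlying PyTorch pattern this commit mirrors, not code from this repository; `model`, `optimizer`, `loss_fn`, and `batch` are placeholder names.

```python
import torch


def training_step(model, optimizer, loss_fn, batch, amp_dtype, scaler=None):
    # Placeholder training step illustrating fp16-vs-bf16 autocast handling.
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        loss = loss_fn(model(batch["x"]), batch["y"])

    if amp_dtype == torch.float16:
        # fp16 gradients can underflow, so the scaler multiplies the loss
        # before backward and unscales it again at optimizer-step time.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # bf16 has the same exponent range as fp32: backpropagate as-is.
        loss.backward()
        optimizer.step()


# fp16 path: a scaler is required.
#   scaler = torch.cuda.amp.GradScaler()
#   training_step(..., amp_dtype=torch.float16, scaler=scaler)
# bf16 path: no scaler.
#   training_step(..., amp_dtype=torch.bfloat16)
```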
```diff
@@ -595,27 +595,26 @@ class Trainer:
         if args.half_precision_backend == "cuda_amp":
             self.use_cuda_amp = True
             self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-            self.do_grad_scaling = True
-            if self.sharded_ddp is not None:
-                self.scaler = ShardedGradScaler()
-            elif self.fsdp is not None:
-                if self.amp_dtype == torch.float16:
+            # bf16 does not need grad scaling
+            self.do_grad_scaling = self.amp_dtype == torch.float16
+            if self.do_grad_scaling:
+                if self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                elif self.fsdp is not None:
                     from torch.distributed.fsdp.sharded_grad_scaler import (
                         ShardedGradScaler as FSDPShardedGradScaler,
                     )

                     self.scaler = FSDPShardedGradScaler()
-                else:
-                    self.do_grad_scaling = False
-                    self.use_cuda_amp = False
-                    self.amp_dtype = None
-            elif is_torch_tpu_available():
-                from torch_xla.amp import GradScaler
+                elif is_torch_tpu_available():
+                    from torch_xla.amp import GradScaler

-                self.scaler = GradScaler()
-            else:
-                self.scaler = torch.cuda.amp.GradScaler()
+                    self.scaler = GradScaler()
+                else:
+                    self.scaler = torch.cuda.amp.GradScaler()
+            elif self.fsdp is not None:
+                self.use_cuda_amp = False
+                self.amp_dtype = None
         elif args.half_precision_backend == "cpu_amp":
             self.use_cpu_amp = True
             self.amp_dtype = torch.bfloat16
@@ -669,7 +668,7 @@ class Trainer:

         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nighly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")

     def add_callback(self, callback):
         """
```
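
After this commit, enabling bf16 through `TrainingArguments` no longer attaches a `GradScaler` to the `Trainer`. The sketch below is a rough usage illustration, not part of the diff: it assumes a Transformers version containing this commit, a CUDA device with bf16 support (otherwise `TrainingArguments(bf16=True)` raises at construction), and a stand-in `nn.Module` in place of a real model.

```python
import torch
from transformers import Trainer, TrainingArguments

# Tiny stand-in model; Trainer accepts any nn.Module for this illustration.
model = torch.nn.Linear(4, 2)

# bf16 mixed precision: after this change, do_grad_scaling stays False and
# no GradScaler is created. (Assumes bf16-capable CUDA hardware.)
args = TrainingArguments(output_dir="tmp-bf16", bf16=True, half_precision_backend="cuda_amp")
trainer = Trainer(model=model, args=args)

print(trainer.do_grad_scaling)  # False: bf16 gradients are not scaled
# With fp16=True instead, do_grad_scaling would be True and a
# torch.cuda.amp.GradScaler would be attached as trainer.scaler.
```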