do not scale gradient in bf16 mode (#21428)
* no dot scale gradient in bf16 mode * fix since args.fp16 might be none * fixed typo * typo * only do if grad scaling is true * self.amp_dtype == torch.float16 is true * put back prop when fsdp is not none
Showing
Please register or sign in to comment