Commit 07916bf2 authored by Jared Casper

Support gradient accumulation fusion in fp16.

parent 2366716f
@@ -168,14 +168,6 @@ def validate_args(args, defaults={}):
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
-    else:
-        if args.gradient_accumulation_fusion:
-            args.gradient_accumulation_fusion = False
-            if args.rank == 0:
-                print('Gradient accumulation fusion to linear layer weight '
-                      'gradient computation is supported only with fp32 '
-                      'gradient accumulation. Setting gradient_accumulation_fusion '
-                      'to False', flush=True)
     # If we use the distributed optimizer, we need to have local DDP
     # and we should make sure use-contiguous-buffers-in-local-ddp is on.
...
@@ -302,7 +302,12 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         if ctx.gradient_accumulation_fusion:
-            fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            if weight.main_grad.dtype == torch.float32:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            elif weight.main_grad.dtype == torch.float16:
+                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
+            else:
+                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
             grad_weight = None
         else:
             grad_weight = grad_output.t().matmul(total_input)
...
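For context, the dtype dispatch above is, semantically, an in-place weight-gradient accumulation into weight.main_grad, which after this commit may be fp16 as well as fp32. Below is a minimal unfused sketch of that behavior; wgrad_accum_reference and the tensor shapes are illustrative assumptions, not code from the commit or from fused_weight_gradient_mlp_cuda.

import torch

# Minimal unfused sketch (illustration only, not the fused CUDA kernel): accumulate
# the weight gradient in place into main_grad, which may be fp32 or fp16.
def wgrad_accum_reference(total_input, grad_output, main_grad):
    # total_input: [tokens, in_features], grad_output: [tokens, out_features]
    if main_grad.dtype not in (torch.float32, torch.float16):
        raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
    # Compute the GEMM in fp32 and accumulate in main_grad's dtype, mirroring the
    # wgrad_gemm_accum_fp32 / wgrad_gemm_accum_fp16 dispatch above.
    wgrad = grad_output.float().t().matmul(total_input.float())
    main_grad.add_(wgrad.to(main_grad.dtype))

# Hypothetical shapes and dtypes for a quick check of the fp16 accumulation path.
total_input = torch.randn(8, 16, dtype=torch.float16)
grad_output = torch.randn(8, 32, dtype=torch.float16)
main_grad = torch.zeros(32, 16, dtype=torch.float16)
wgrad_accum_reference(total_input, grad_output, main_grad)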