Unverified commit 06eebf66, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] Mcore DDP support (#446)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 76669cdd
@@ -226,9 +226,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
         if "layer_norm_weight" in name:
             continue
         elif "weight" in name and p.requires_grad:
-            assert (
-                p.grad is None and torch.count_nonzero(p.main_grad) > 0
-            ), "Gradient not accumulated."
+            assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


 def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
...
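The relaxed assertion reflects the new contract: with fused wgrad accumulation the weight gradient lands in p.main_grad, and on the Mcore DDP path p.grad is no longer guaranteed to stay None, so only main_grad is checked. A standalone version of that check, written as a hedged sketch against a generic module:

import torch

def check_wgrad_accumulated(module):
    # Mirrors the updated test: only main_grad must hold the fused wgrad;
    # no assumption is made about the state of p.grad anymore.
    for name, p in module.named_parameters():
        if "layer_norm_weight" in name:
            continue
        if "weight" in name and p.requires_grad:
            assert torch.count_nonzero(p.main_grad) > 0, f"Gradient not accumulated: {name}"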
@@ -45,7 +45,6 @@ def fp8_gemm(
     assert_dim_for_fp8_exec(A)
     assert_dim_for_fp8_exec(B)

-    return_output = False
     if out is None:
         out = torch.empty(
             B.shape[0],
@@ -53,7 +52,7 @@ def fp8_gemm(
             dtype=out_dtype,
             device="cuda",
         )
-        return_output = True
+
     # Use bfloat16 as default bias_dtype
     bias_dtype = torch.bfloat16 if bias is None else bias.dtype

     if gelu:
@@ -110,13 +109,7 @@ def fp8_gemm(
         args = tuple(args + (True, extra_output_tensor,))
     _ = fn(*args)

-    if return_output:
-        if gelu:
-            return out, gelu_input
-        return out
-    if gelu:
-        return gelu_input
-    return None
+    return out, gelu_input


 def gemm(
@@ -144,7 +137,6 @@ def gemm(
     empty_tensor = torch.Tensor()
     fp8_index = -1 # dummy index

-    return_output = False
     if out is None:
         out = torch.empty(
             B.shape[1] if transb else B.shape[0],
@@ -152,7 +144,6 @@ def gemm(
             dtype=dtype,
             device="cuda",
         )
-        return_output = True

     if gelu and not grad:
         gelu_input = torch.empty_like(out, dtype=dtype)
@@ -222,6 +213,4 @@ def gemm(
         args = tuple(args + (False, extra_output_tensor,))
     _ = fn(*args)

-    if return_output:
-        return out, grad_bias, gelu_input
-    return None, grad_bias, gelu_input
+    return out, grad_bias, gelu_input
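With return_output gone, fp8_gemm and gemm return the same tuple whether or not the caller supplied a preallocated out, which is why every call site below now unpacks the result. A minimal stand-in (not the real extension API) showing the new calling convention:

import torch

def fp8_gemm_like(a, b, out=None):
    # Stand-in for tex.fp8_gemm after this change: allocate out only when the
    # caller did not pass one, and always return the (out, gelu_input) pair.
    if out is None:
        out = torch.empty(b.shape[0], a.shape[0])
    torch.matmul(b, a.t(), out=out)
    gelu_input = None  # populated only when the fused GeLU epilogue is used
    return out, gelu_input

a = torch.randn(16, 32)               # "weight"
b = torch.randn(8, 32)                # "input"
preallocated = torch.empty(8, 16)
out, _ = fp8_gemm_like(a, b, out=preallocated)   # callers always unpack
assert out is preallocated

tex.gemm is analogous but always returns (out, grad_bias, gelu_input), so call sites that preallocate out simply discard the handles with _ = tex.gemm(...).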
@@ -173,7 +173,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
                 fp8_dtype_forward)

-            out = tex.fp8_gemm(
+            out, _ = tex.fp8_gemm(
                 weight_fp8,
                 fp8_meta["scaling_fwd"].scale_inv,
                 tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -389,7 +389,7 @@ class _LayerNormLinear(torch.autograd.Function):
                # WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
-                    wgrad = tex.fp8_gemm(
+                    wgrad, _ = tex.fp8_gemm(
                        ln_out_total_t,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_INPUT,
@@ -444,7 +444,6 @@ class _LayerNormLinear(torch.autograd.Function):
                ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
            )
            if ctx.ub_bulk_wgrad:
                dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
-
        # Column Parallel Linear
@@ -474,6 +473,9 @@ class _LayerNormLinear(torch.autograd.Function):
        if not ctx.use_bias:
            grad_bias = None

+        # Handle custom DDP from mcore.
+        weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
        return (
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            dgamma,
...
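weight.grad_added_to_main_grad is the handshake with Megatron-core's custom DistributedDataParallel: when wgrad accumulation is fused, the gradient already sits in weight.main_grad, and the flag tells the DDP grad hook not to add param.grad on top of it. A minimal sketch of such a hook; this is not Megatron-core's actual code, just the pattern the flag enables:

import torch

def accumulate_into_main_grad(param):
    # Illustrative DDP-side handling of the flag set in this commit.
    if getattr(param, "grad_added_to_main_grad", False):
        # The fused wgrad GEMM already accumulated into main_grad; any
        # param.grad produced by autograd is a dummy and must be ignored.
        pass
    elif param.grad is not None:
        param.main_grad.add_(param.grad)
    param.grad = None  # main_grad is the single source of truth afterwards

p = torch.nn.Parameter(torch.zeros(4))
p.main_grad = torch.zeros_like(p)
p.grad = torch.ones(4)
accumulate_into_main_grad(p)   # no flag set -> grad is folded into main_grad
print(p.main_grad)             # tensor([1., 1., 1., 1.])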
@@ -223,7 +223,7 @@ class _LayerNormMLP(torch.autograd.Function):
                fp8_dtype_forward,
            )

-            fc1_out = tex.fp8_gemm(
+            fc1_out, _ = tex.fp8_gemm(
                fc1_weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -344,7 +344,7 @@ class _LayerNormMLP(torch.autograd.Function):
            dim_size = list(gelu_out.size())
            dim_size[1] = fc2_weight.size(0)
            fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
-            _, _, _ = tex.gemm(
+            _ = tex.gemm(
                fc2_weight,
                gelu_out,
                activation_dtype,
@@ -498,7 +498,7 @@ class _LayerNormMLP(torch.autograd.Function):
            )

            # FC2 DGRAD; Unconditional
-            fc2_dgrad = tex.fp8_gemm(
+            fc2_dgrad, _ = tex.fp8_gemm(
                fc2_weight_t_fp8,
                fwd_scale_inverses,
                tex.FP8FwdTensors.GEMM2_WEIGHT,
@@ -519,7 +519,7 @@ class _LayerNormMLP(torch.autograd.Function):
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                if fc2_weight.requires_grad:
                    gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
-                    fc2_wgrad = tex.fp8_gemm(
+                    fc2_wgrad, _ = tex.fp8_gemm(
                        gelu_out_t,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM2_INPUT,
@@ -675,7 +675,7 @@ class _LayerNormMLP(torch.autograd.Function):
                fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
            )
            # FC1 DGRAD: Unconditional
-            _, _, _ = tex.gemm(
+            _ = tex.gemm(
                fc1_weight,
                dgelu,
                ctx.activation_dtype,
@@ -705,7 +705,7 @@ class _LayerNormMLP(torch.autograd.Function):
            # FC1 WGRAD
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
-                fc1_wgrad = tex.fp8_gemm(
+                fc1_wgrad, _ = tex.fp8_gemm(
                    ln_out_total_t,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_INPUT,
@@ -794,6 +794,10 @@ class _LayerNormMLP(torch.autograd.Function):
            )
            dbeta = None

+        # Handle custom DDP from mcore.
+        fc1_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+        fc2_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
        return (
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            dgamma,
...
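For fused wgrad accumulation to kick in, the module has to be built with fuse_wgrad_accumulation=True and every weight needs a main_grad buffer before backward; the flags added above are then set during the backward pass. A rough usage sketch, assuming a CUDA device and Transformer Engine installed (sizes and buffer dtype are illustrative):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    1024,                              # hidden size
    4096,                              # FFN hidden size
    fuse_wgrad_accumulation=True,      # wgrad GEMMs accumulate into main_grad
).cuda()

# Persistent fp32 gradient buffers, as a Megatron-core style DDP would own them.
for p in mlp.parameters():
    p.main_grad = torch.zeros_like(p, dtype=torch.float32)

out = mlp(torch.randn(128, 1024, device="cuda"))
out.sum().backward()
# Weight gradients now live in p.main_grad, and the FC1/FC2 weights carry
# grad_added_to_main_grad=True for a custom DDP to inspect.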
@@ -211,7 +211,7 @@ class _Linear(torch.autograd.Function):
            dim_size[1] = weight.size(0)
            out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

-            _, _, _ = gemm(
+            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
@@ -325,7 +325,7 @@ class _Linear(torch.autograd.Function):
        if ctx.requires_dgrad:
            if ctx.fp8:
-                dgrad = fp8_gemm(
+                dgrad, _ = fp8_gemm(
                    weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -368,7 +368,7 @@ class _Linear(torch.autograd.Function):
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                if ctx.ub_split_ag:
                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
-                wgrad = fp8_gemm(
+                wgrad, _ = fp8_gemm(
                    inputmat_t_total,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_INPUT,
@@ -415,6 +415,9 @@ class _Linear(torch.autograd.Function):
        if not ctx.use_bias:
            grad_bias = None

+        # Handle custom DDP from mcore.
+        weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
+
        return (
            wgrad if weight.requires_grad else None,
            None,
...
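Because the flag is set unconditionally in backward, whoever owns the main_grad buffers is also expected to clear it between iterations together with the buffers themselves (a Megatron-core style DDP does the equivalent when it zeroes its gradient buffers). A minimal hand-written reset helper, illustrative rather than mcore API:

def zero_main_grads(module):
    # Reset the fp32 accumulation buffers and the per-parameter handshake
    # flag before the next forward/backward pass.
    for p in module.parameters():
        if hasattr(p, "main_grad"):
            p.main_grad.zero_()
        if hasattr(p, "grad_added_to_main_grad"):
            p.grad_added_to_main_grad = False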