Unverified commit 06eebf66 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Mcore DDP support (#446)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 76669cdd
......@@ -226,9 +226,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, bs, dtype, config, fp8_
if "layer_norm_weight" in name:
continue
elif "weight" in name and p.requires_grad:
assert (
p.grad is None and torch.count_nonzero(p.main_grad) > 0
), "Gradient not accumulated."
assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
......
......@@ -45,7 +45,6 @@ def fp8_gemm(
assert_dim_for_fp8_exec(A)
assert_dim_for_fp8_exec(B)
return_output = False
if out is None:
out = torch.empty(
B.shape[0],
......@@ -53,7 +52,7 @@ def fp8_gemm(
dtype=out_dtype,
device="cuda",
)
return_output = True
# Use bfloat16 as default bias_dtype
bias_dtype = torch.bfloat16 if bias is None else bias.dtype
if gelu:
......@@ -110,13 +109,7 @@ def fp8_gemm(
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
if return_output:
if gelu:
return out, gelu_input
return out
if gelu:
return gelu_input
return None
return out, gelu_input
def gemm(
......@@ -144,7 +137,6 @@ def gemm(
empty_tensor = torch.Tensor()
fp8_index = -1 # dummy index
return_output = False
if out is None:
out = torch.empty(
B.shape[1] if transb else B.shape[0],
......@@ -152,7 +144,6 @@ def gemm(
dtype=dtype,
device="cuda",
)
return_output = True
if gelu and not grad:
gelu_input = torch.empty_like(out, dtype=dtype)
......@@ -222,6 +213,4 @@ def gemm(
args = tuple(args + (False, extra_output_tensor,))
_ = fn(*args)
if return_output:
return out, grad_bias, gelu_input
return None, grad_bias, gelu_input
return out, grad_bias, gelu_input
......@@ -173,7 +173,7 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = tex.fp8_gemm(
out, _ = tex.fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -389,7 +389,7 @@ class _LayerNormLinear(torch.autograd.Function):
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = tex.fp8_gemm(
wgrad, _ = tex.fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
......@@ -444,7 +444,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear
......@@ -474,6 +473,9 @@ class _LayerNormLinear(torch.autograd.Function):
if not ctx.use_bias:
grad_bias = None
# Handle custom DDP from mcore.
weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
......
......@@ -223,7 +223,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
)
fc1_out = tex.fp8_gemm(
fc1_out, _ = tex.fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -344,7 +344,7 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
_, _, _ = tex.gemm(
_ = tex.gemm(
fc2_weight,
gelu_out,
activation_dtype,
......@@ -498,7 +498,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC2 DGRAD; Unconditional
fc2_dgrad = tex.fp8_gemm(
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
......@@ -519,7 +519,7 @@ class _LayerNormMLP(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = tex.fp8_gemm(
fc2_wgrad, _ = tex.fp8_gemm(
gelu_out_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_INPUT,
......@@ -675,7 +675,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
# FC1 DGRAD: Unconditional
_, _, _ = tex.gemm(
_ = tex.gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
......@@ -705,7 +705,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = tex.fp8_gemm(
fc1_wgrad, _ = tex.fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
......@@ -794,6 +794,10 @@ class _LayerNormMLP(torch.autograd.Function):
)
dbeta = None
# Handle custom DDP from mcore.
fc1_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
fc2_weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
......
......@@ -211,7 +211,7 @@ class _Linear(torch.autograd.Function):
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_, _, _ = gemm(
_ = gemm(
weight,
inputmat_total,
activation_dtype,
......@@ -325,7 +325,7 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
if ctx.fp8:
dgrad = fp8_gemm(
dgrad, _ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -368,7 +368,7 @@ class _Linear(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
wgrad = fp8_gemm(
wgrad, _ = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
......@@ -415,6 +415,9 @@ class _Linear(torch.autograd.Function):
if not ctx.use_bias:
grad_bias = None
# Handle custom DDP from mcore.
weight.grad_added_to_main_grad = ctx.fuse_wgrad_accumulation
return (
wgrad if weight.requires_grad else None,
None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment