Unverified commit a2e19b7a authored by Kirthi Shankar Sivamani, committed by GitHub

Conditional dgrad computation for Linear API (#134)



* small cleanup before starting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* conditional dgrad for Linear
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add tests and small improvements to LNLinear and LNMLP
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4105125a
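
A minimal usage sketch of what this change enables, assuming the public transformer_engine.pytorch.Linear API; the shapes and the final checks are illustrative only, not part of the commit:

    import torch
    import transformer_engine.pytorch as te

    # Linear layer whose input does not require gradients
    # (e.g. features coming from a frozen upstream module).
    linear = te.Linear(1024, 1024, bias=True).cuda()
    inp = torch.randn(128, 1024, device="cuda", requires_grad=False)

    out = linear(inp)
    out.sum().backward()

    # With conditional dgrad, backward() skips the input-gradient GEMM entirely.
    assert inp.grad is None                   # no dgrad was computed
    assert linear.weight.grad is not None     # wgrad is still produced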
@@ -237,9 +237,12 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     torch.cuda.synchronize()
 
 
-def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
+def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad):
+    if skip_dgrad and skip_wgrad:
+        pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
+
     te_inp = torch.randn(
-        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
+        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad
     ).cuda()
 
     if skip_wgrad:
@@ -261,7 +264,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+@pytest.mark.parametrize("skip_dgrad", all_boolean)
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -281,7 +285,7 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
@@ -289,7 +293,8 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
-def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
+@pytest.mark.parametrize("skip_dgrad", all_boolean)
+def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad, skip_dgrad):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -305,7 +310,7 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
@@ -314,7 +319,8 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+@pytest.mark.parametrize("skip_dgrad", all_boolean)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -336,7 +342,7 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
@@ -870,6 +870,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.return_layernorm_output = return_layernorm_output
         ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         ctx.zero_centered_gamma = zero_centered_gamma
+        ctx.requires_dgrad = inp.requires_grad
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -1040,7 +1041,7 @@ class _LayerNormLinear(torch.autograd.Function):
             grad_bias = None
 
         return (
-            dxmat.view(ctx.inp_shape),
+            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
             dbeta,
             wgrad if weight.requires_grad else None,
@@ -1562,7 +1563,7 @@ class _Linear(torch.autograd.Function):
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
-        ctx.requires_wgrad = weight.requires_grad
+        ctx.requires_dgrad = inp.requires_grad
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -1601,11 +1602,11 @@ class _Linear(torch.autograd.Function):
         if ctx.parallel_mode == "column" and ctx.sequence_parallel:
             if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                 inputmat_t_total, handle = gather_along_last_dim(
-                    inputmat_t, ctx.tp_group, async_op=True
+                    inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
                 )
             else:
                 inputmat_total, handle = gather_along_first_dim(
-                    inputmat, ctx.tp_group, async_op=True
+                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                 )
         else:
             inputmat_t_total = inputmat_t
@@ -1626,41 +1627,41 @@ class _Linear(torch.autograd.Function):
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
 
-            # DGRAD
-            dgrad = fp8_gemm(
-                weight_t_fp8,
-                fwd_scale_inverses,
-                tex.FP8FwdTensors.GEMM1_WEIGHT,
-                fp8_dtype_forward,
-                grad_output_c,
-                ctx.fp8_meta["scaling_bwd"].scale_inv,
-                tex.FP8BwdTensors.GRAD_OUTPUT1,
-                fp8_dtype_backward,
-                ctx.activation_dtype,
-                get_workspace(),
-                use_split_accumulator=_2X_ACC_DGRAD,
-            )
-        else:
-            # DGRAD
-            dgrad, _, _ = gemm(
-                weight,
-                grad_output,
-                ctx.activation_dtype,
-                get_workspace(),
-                layout="NN",
-                grad=True,
-            )
-
-        # Overlap dgrad-RS/AR with wgrad
-        if ctx.parallel_mode == "column" and ctx.sequence_parallel:
-            handle.wait()
-            dgrad, handle = reduce_scatter_along_first_dim(
-                dgrad, ctx.tp_group, async_op=True
-            )
-        elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
-            dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
-
-        if ctx.requires_wgrad:
+        if ctx.requires_dgrad:
+            if ctx.fp8:
+                dgrad = fp8_gemm(
+                    weight_t_fp8,
+                    fwd_scale_inverses,
+                    tex.FP8FwdTensors.GEMM1_WEIGHT,
+                    fp8_dtype_forward,
+                    grad_output_c,
+                    ctx.fp8_meta["scaling_bwd"].scale_inv,
+                    tex.FP8BwdTensors.GRAD_OUTPUT1,
+                    fp8_dtype_backward,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    use_split_accumulator=_2X_ACC_DGRAD,
+                )
+            else:
+                dgrad, _, _ = gemm(
+                    weight,
+                    grad_output,
+                    ctx.activation_dtype,
+                    get_workspace(),
+                    layout="NN",
+                    grad=True,
+                )
+
+            # Overlap dgrad-RS/AR with wgrad
+            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+                handle.wait()
+                dgrad, handle = reduce_scatter_along_first_dim(
+                    dgrad, ctx.tp_group, async_op=True
+                )
+            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
+                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
+
+        if weight.requires_grad:
             if ctx.fp8:
                 # WGRAD
                 if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -1712,10 +1713,10 @@ class _Linear(torch.autograd.Function):
             grad_bias = None
 
         return (
-            wgrad if ctx.requires_wgrad else None,
+            wgrad if weight.requires_grad else None,
             None,
             None,
-            dgrad.view(ctx.inp_shape),
+            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             grad_bias,
             None,
             None,
@@ -2281,6 +2282,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.set_parallel_mode = set_parallel_mode
         ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         ctx.zero_centered_gamma = zero_centered_gamma
+        ctx.requires_dgrad = inp.requires_grad
 
         # Row Parallel Linear
        if set_parallel_mode and sequence_parallel:
@@ -2575,7 +2577,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc2_bias_grad = None
 
         return (
-            dxmat.view(ctx.inp_shape),
+            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
             dbeta,
             fc1_wgrad if fc1_weight.requires_grad else None,
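
The pattern the diff applies to _Linear, _LayerNormLinear, and _LayerNormMLP can be distilled into a small, self-contained torch.autograd.Function. The sketch below is illustrative only and is not Transformer Engine's implementation:

    import torch

    class ToyLinearFn(torch.autograd.Function):
        """Toy linear op illustrating conditional dgrad/wgrad computation."""

        @staticmethod
        def forward(ctx, inp, weight):
            # Mirror the diff: record at forward time whether the input needs a
            # gradient, so backward can skip the dgrad GEMM when it does not.
            ctx.requires_dgrad = inp.requires_grad
            ctx.save_for_backward(inp, weight)
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_output):
            inp, weight = ctx.saved_tensors
            # DGRAD: d(loss)/d(input), only when the input requires it.
            dgrad = grad_output @ weight if ctx.requires_dgrad else None
            # WGRAD: d(loss)/d(weight), only when the weight requires it.
            wgrad = grad_output.t() @ inp if weight.requires_grad else None
            return dgrad, wgrad

Calling ToyLinearFn.apply(x, w) with x.requires_grad=False makes backward return None for the input gradient while still producing the weight gradient, which is what the conditional checks above achieve for the real GEMM-based kernels.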