Unverified Commit 78210127 authored by schetlur-nv's avatar schetlur-nv Committed by GitHub
Browse files

Conditional wgrad support (#21)



* Conditional dgrad/wgrad support
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixing the change to depend only on requires_grad. Also updating LayerNorm MLP
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Minor fixes.
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding conditional wgrad for LayerNormLinear
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* bug fix and remove conditional dgrad

Co-authored-by: schetlur-nv schetlur@nvidia.com
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adding unit test for wgrad disabled path
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding more unit tests for wgrad disabled path
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding unit tests for fp8 wgrad disabling, and cleaning up the code.
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* fix lint errors
Co-Authored-By: default avatarSharan Chetlur <schetlur@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 924892fd
......@@ -86,8 +86,15 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
skip_wgrad = [True, False]
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
......@@ -103,6 +110,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
......@@ -113,7 +124,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -129,6 +140,10 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
......@@ -136,7 +151,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -152,6 +167,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
......@@ -161,10 +180,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp)
if isinstance(te_out, tuple):
......@@ -178,7 +201,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -194,14 +218,15 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -214,14 +239,15 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -239,14 +265,15 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -271,14 +298,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -303,14 +331,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -336,14 +365,15 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -366,14 +396,15 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -399,14 +430,15 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -432,4 +464,4 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
......@@ -37,14 +37,22 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
skip_wgrad = [True, False]
def _test_sanity_e2e_amp(block, bs, dtype, config):
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
return
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
......@@ -57,6 +65,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
......@@ -66,7 +78,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config):
def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -82,13 +94,17 @@ def _test_sanity_e2e(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config):
def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -104,6 +120,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
)
......@@ -112,10 +132,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config):
def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp)
if isinstance(te_out, tuple):
te_out = te_out[0]
......@@ -127,7 +151,8 @@ def _test_sanity_common(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -143,13 +168,14 @@ def test_sanity_layernorm_linear(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -162,13 +188,15 @@ def test_sanity_linear(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -186,13 +214,14 @@ def test_sanity_layernorm_mlp(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_gpt(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -217,13 +246,14 @@ def test_sanity_gpt(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_bert(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -248,13 +278,14 @@ def test_sanity_bert(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_T5(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -280,13 +311,14 @@ def test_sanity_T5(dtype, bs, model):
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config)
_test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -309,13 +341,14 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model):
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config)
_test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -341,13 +374,14 @@ def test_sanity_drop_path(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -373,4 +407,4 @@ def test_sanity_fused_qkv_params(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
......@@ -683,7 +683,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
......@@ -696,7 +696,7 @@ class _LayerNormLinear(torch.autograd.Function):
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
# DGRAD: Evaluated unconditionally to feed into Linear backward
dgrad, _, _ = gemm(
weight,
grad_output,
......@@ -715,6 +715,7 @@ class _LayerNormLinear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
......@@ -795,7 +796,7 @@ class _LayerNormLinear(torch.autograd.Function):
dxmat.view(ctx.inp_shape),
dgamma,
dbeta,
wgrad,
wgrad if weight.requires_grad else None,
None,
None,
grad_bias,
......@@ -1292,6 +1293,7 @@ class _Linear(torch.autograd.Function):
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
......@@ -1350,7 +1352,7 @@ class _Linear(torch.autograd.Function):
)
return (
wgrad,
wgrad if weight.requires_grad else None,
None,
None,
dgrad.view(ctx.inp_shape),
......@@ -1858,7 +1860,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FC2 DGRAD
# FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
......@@ -1873,8 +1875,8 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
......@@ -1888,7 +1890,9 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
......@@ -1900,6 +1904,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_backward,
)
else:
if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
......@@ -1917,12 +1922,15 @@ class _LayerNormMLP(torch.autograd.Function):
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
dgelu = cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
......@@ -1931,7 +1939,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
dgelu_t = None
# FC1 DGRAD
# FC1 DGRAD: Unconditional
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
......@@ -1944,7 +1952,7 @@ class _LayerNormMLP(torch.autograd.Function):
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# FC2 DGRAD
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
......@@ -1957,6 +1965,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC2 WGRAD
if fc2_weight.requires_grad:
fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out,
grad_output,
......@@ -1975,7 +1984,7 @@ class _LayerNormMLP(torch.autograd.Function):
else:
dgelu = fc2_dgrad
# FC1 DGRAD
# FC1 DGRAD: Unconditional
fc1_dgrad, _, _ = gemm(
fc1_weight,
dgelu,
......@@ -1994,6 +2003,7 @@ class _LayerNormMLP(torch.autograd.Function):
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
if fc1_weight.requires_grad:
if ctx.fp8:
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
......@@ -2011,7 +2021,9 @@ class _LayerNormMLP(torch.autograd.Function):
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
......@@ -2031,7 +2043,9 @@ class _LayerNormMLP(torch.autograd.Function):
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
out=fc1_weight.main_grad
if ctx.fuse_wgrad_accumulation
else None,
)
else:
# FC1 WGRAD
......@@ -2079,11 +2093,11 @@ class _LayerNormMLP(torch.autograd.Function):
dxmat.view(ctx.inp_shape),
dgamma,
dbeta,
fc1_wgrad,
fc1_wgrad if fc1_weight.requires_grad else None,
None,
None,
fc1_bias_grad,
fc2_wgrad,
fc2_wgrad if fc2_weight.requires_grad else None,
None,
None,
fc2_bias_grad,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment