Unverified commit 78210127, authored by schetlur-nv and committed by GitHub

Conditional wgrad support (#21)



* Conditional dgrad/wgrad support
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixing the change to depend only on requires_grad. Also updating LayerNorm MLP
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Minor fixes.
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding conditional wgrad for LayerNormLinear
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* bug fix and remove conditional dgrad

Co-authored-by: schetlur-nv <schetlur@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adding unit test for wgrad disabled path
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding more unit tests for wgrad disabled path
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding unit tests for fp8 wgrad disabling, and cleaning up the code.
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* fix lint errors
Co-authored-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 924892fd
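
In effect, this change makes the weight-gradient (wgrad) GEMMs in Linear, LayerNormLinear, and LayerNormMLP conditional on the parameters' `requires_grad` flag, while the data-gradient (dgrad) GEMMs remain unconditional so gradients keep flowing to earlier layers. A minimal usage sketch, assuming Transformer Engine's PyTorch `Linear` module on a CUDA device; the sizes below are arbitrary and chosen only for illustration:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical layer sizes, for illustration only.
linear = te.Linear(1024, 1024, bias=True).cuda()

# Freezing the parameters is the whole interface: backward detects
# requires_grad == False and skips the wgrad GEMM, while the dgrad
# GEMM still runs so the input gradient is produced.
for p in linear.parameters():
    p.requires_grad = False

inp = torch.randn(128, 1024, device="cuda", requires_grad=True)
out = linear(inp)
out.sum().backward()

assert inp.grad is not None        # dgrad still flows to the input
assert linear.weight.grad is None  # wgrad was skipped
```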
@@ -86,8 +86,15 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
 batch_sizes = [1, 2]
+skip_wgrad = [True, False]
+
+
+def _disable_wgrads(block):
+    for p in block.parameters():
+        p.requires_grad = False
+
+
-def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
@@ -103,6 +110,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
         with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, te_inp_attn_mask)
@@ -113,7 +124,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()


-def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -129,6 +140,10 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
         loss = te_out.sum()
@@ -136,7 +151,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()


-def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -152,6 +167,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
@@ -161,10 +180,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()


-def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(te_inp)
         if isinstance(te_out, tuple):
@@ -178,7 +201,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -194,14 +218,15 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_linear(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -214,14 +239,15 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -239,14 +265,15 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_gpt(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -271,14 +298,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_bert(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -303,14 +331,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_T5(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -336,14 +365,15 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -366,14 +396,15 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -399,14 +430,15 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -432,4 +464,4 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
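
The same pattern repeats in the non-FP8 sanity tests below: a module-level `skip_wgrad = [True, False]` list, a `_disable_wgrads` helper, and one extra `@pytest.mark.parametrize` per test. A stripped-down, self-contained sketch of that pattern, using a plain `torch.nn.Linear` as a stand-in for the Transformer Engine blocks exercised by the real tests:

```python
import pytest
import torch

skip_wgrad = [True, False]


def _disable_wgrads(block):
    for p in block.parameters():
        p.requires_grad = False


@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_linear_stub(skip_wgrad):
    # Stand-in module; shapes are arbitrary.
    block = torch.nn.Linear(16, 16)

    if skip_wgrad:
        _disable_wgrads(block)

    inp = torch.randn(4, 16, requires_grad=True)
    block(inp).sum().backward()

    # dgrad always reaches the input; wgrad only when it is enabled.
    assert inp.grad is not None
    assert (block.weight.grad is None) == skip_wgrad
```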
@@ -37,14 +37,22 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
 batch_sizes = [1, 2]
+skip_wgrad = [True, False]
+
+
+def _disable_wgrads(block):
+    for p in block.parameters():
+        p.requires_grad = False
+
+
-def _test_sanity_e2e_amp(block, bs, dtype, config):
+def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
     if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
         return

     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()

     te_inp_attn_mask = (
         torch.rand(
             (
@@ -57,6 +65,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
         loss = te_out.sum()
@@ -66,7 +78,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
     torch.cuda.synchronize()


-def _test_sanity_e2e(block, bs, dtype, config):
+def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -82,13 +94,17 @@ def _test_sanity_e2e(block, bs, dtype, config):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(te_inp_hidden_states, te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()


-def _test_sanity_e2e_T5(block, bs, dtype, config):
+def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -104,6 +120,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(
         te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
     )
@@ -112,10 +132,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
     torch.cuda.synchronize()


-def _test_sanity_common(block, bs, dtype, config):
+def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
     te_inp = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(te_inp)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
@@ -127,7 +151,8 @@ def _test_sanity_common(block, bs, dtype, config):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_linear(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -143,13 +168,14 @@ def test_sanity_layernorm_linear(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_linear(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_linear(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -162,13 +188,15 @@ def test_sanity_linear(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_mlp(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -186,13 +214,14 @@ def test_sanity_layernorm_mlp(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_gpt(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_gpt(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -217,13 +246,14 @@ def test_sanity_gpt(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_bert(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_bert(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -248,13 +278,14 @@ def test_sanity_bert(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_T5(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_T5(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -280,13 +311,14 @@ def test_sanity_T5(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e_T5(block, bs, dtype, config)
+    _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_amp_and_nvfuser(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -309,13 +341,14 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e_amp(block, bs, dtype, config)
+    _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_drop_path(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -341,13 +374,14 @@ def test_sanity_drop_path(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)


 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_fused_qkv_params(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
     sigma = 0.023
@@ -373,4 +407,4 @@ def test_sanity_fused_qkv_params(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@@ -683,7 +683,7 @@ class _LayerNormLinear(torch.autograd.Function):
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

-            # DGRAD
+            # DGRAD: Evaluated unconditionally to feed into Linear backward
            dgrad = fp8_gemm(
                weight_t_fp8,
                fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
@@ -696,7 +696,7 @@ class _LayerNormLinear(torch.autograd.Function):
                use_split_accumulator=_2X_ACC_DGRAD,
            )
        else:
-            # DGRAD
+            # DGRAD: Evaluated unconditionally to feed into Linear backward
            dgrad, _, _ = gemm(
                weight,
                grad_output,
@@ -715,6 +715,7 @@ class _LayerNormLinear(torch.autograd.Function):
        elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
            dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

+        if weight.requires_grad:
            if ctx.fp8:
                # WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -795,7 +796,7 @@ class _LayerNormLinear(torch.autograd.Function):
            dxmat.view(ctx.inp_shape),
            dgamma,
            dbeta,
-            wgrad,
+            wgrad if weight.requires_grad else None,
            None,
            None,
            grad_bias,
@@ -1292,6 +1293,7 @@ class _Linear(torch.autograd.Function):
        elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
            dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

+        if weight.requires_grad:
            if ctx.fp8:
                # WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -1350,7 +1352,7 @@ class _Linear(torch.autograd.Function):
        )

        return (
-            wgrad,
+            wgrad if weight.requires_grad else None,
            None,
            None,
            dgrad.view(ctx.inp_shape),
@@ -1858,7 +1860,7 @@ class _LayerNormMLP(torch.autograd.Function):
                ctx.fp8_meta["recipe"], fprop_tensor=False
            )

-            # FC2 DGRAD
+            # FC2 DGRAD; Unconditional
            fc2_dgrad = fp8_gemm(
                fc2_weight_t_fp8,
                fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
@@ -1873,8 +1875,8 @@ class _LayerNormMLP(torch.autograd.Function):
            # FC2 WGRAD
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
+                if fc2_weight.requires_grad:
                    gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
                    fc2_wgrad = fp8_gemm(
                        gelu_out_t,
                        fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
@@ -1888,7 +1890,9 @@ class _LayerNormMLP(torch.autograd.Function):
                        get_workspace(),
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
-                        out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                        out=fc2_weight.main_grad
+                        if ctx.fuse_wgrad_accumulation
+                        else None,
                        use_split_accumulator=_2X_ACC_WGRAD,
                    )
@@ -1900,6 +1904,7 @@ class _LayerNormMLP(torch.autograd.Function):
                        fp8_dtype_backward,
                    )
            else:
+                if fc2_weight.requires_grad:
                    gelu_out_c = cast_from_fp8(
                        gelu_out,
                        ctx.fp8_meta["scaling_fwd"],
@@ -1917,12 +1922,15 @@ class _LayerNormMLP(torch.autograd.Function):
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
-                        out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                        out=fc2_weight.main_grad
+                        if ctx.fuse_wgrad_accumulation
+                        else None,
                    )

                fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
                    fc2_dgrad, fc1_out, fc1_bias
                )

                dgelu = cast_to_fp8(
                    dgelu_no_fp8,
                    ctx.fp8_meta["scaling_bwd"],
@@ -1931,7 +1939,7 @@ class _LayerNormMLP(torch.autograd.Function):
                )
                dgelu_t = None

-            # FC1 DGRAD
+            # FC1 DGRAD: Unconditional
            fc1_dgrad = fp8_gemm(
                fc1_weight_t_fp8,
                fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
@@ -1944,7 +1952,7 @@ class _LayerNormMLP(torch.autograd.Function):
                use_split_accumulator=_2X_ACC_DGRAD,
            )
        else:
-            # FC2 DGRAD
+            # FC2 DGRAD; Unconditional
            fc2_dgrad, _, _ = gemm(
                fc2_weight,
                grad_output,
@@ -1957,6 +1965,7 @@ class _LayerNormMLP(torch.autograd.Function):
            )

            # FC2 WGRAD
+            if fc2_weight.requires_grad:
                fc2_wgrad, fc2_bias_grad, _ = gemm(
                    gelu_out,
                    grad_output,
@@ -1975,7 +1984,7 @@ class _LayerNormMLP(torch.autograd.Function):
            else:
                dgelu = fc2_dgrad

-            # FC1 DGRAD
+            # FC1 DGRAD: Unconditional
            fc1_dgrad, _, _ = gemm(
                fc1_weight,
                dgelu,
@@ -1994,6 +2003,7 @@ class _LayerNormMLP(torch.autograd.Function):
        elif ctx.set_parallel_mode and ctx.tensor_parallel:
            fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)

+        if fc1_weight.requires_grad:
            if ctx.fp8:
                # FC1 WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
@@ -2011,7 +2021,9 @@ class _LayerNormMLP(torch.autograd.Function):
                        get_workspace(),
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
-                        out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                        out=fc1_weight.main_grad
+                        if ctx.fuse_wgrad_accumulation
+                        else None,
                        use_split_accumulator=_2X_ACC_WGRAD,
                    )
                else:
@@ -2031,7 +2043,9 @@ class _LayerNormMLP(torch.autograd.Function):
                        grad=True,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
-                        out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                        out=fc1_weight.main_grad
+                        if ctx.fuse_wgrad_accumulation
+                        else None,
                    )
            else:
                # FC1 WGRAD
@@ -2079,11 +2093,11 @@ class _LayerNormMLP(torch.autograd.Function):
            dxmat.view(ctx.inp_shape),
            dgamma,
            dbeta,
-            fc1_wgrad,
+            fc1_wgrad if fc1_weight.requires_grad else None,
            None,
            None,
            fc1_bias_grad,
-            fc2_wgrad,
+            fc2_wgrad if fc2_weight.requires_grad else None,
            None,
            None,
            fc2_bias_grad,
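
On the module side the recurring structure is: dgrad GEMMs stay unconditional, wgrad GEMMs are wrapped in `if weight.requires_grad:` (or `fc1_weight` / `fc2_weight`) guards, and the matching slots in backward's return tuple become `wgrad if weight.requires_grad else None`. A reduced sketch of that structure in a custom `torch.autograd.Function`, with plain matmuls standing in for the FP8/cuBLAS GEMM helpers used in the real code:

```python
import torch


class _CondWgradLinear(torch.autograd.Function):
    """Toy linear op mirroring the conditional-wgrad structure above."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors

        # DGRAD: evaluated unconditionally so earlier layers still train.
        dgrad = grad_output @ weight

        # WGRAD: only pay for the GEMM when the weight actually needs it.
        wgrad = None
        if weight.requires_grad:
            wgrad = grad_output.t() @ inp

        # Return-slot guard mirrors the diff above.
        return dgrad, wgrad if weight.requires_grad else None


# Freezing the weight skips the wgrad matmul entirely.
inp = torch.randn(8, 4, requires_grad=True)
weight = torch.randn(4, 4)  # requires_grad defaults to False
out = _CondWgradLinear.apply(inp, weight)
out.sum().backward()
assert inp.grad is not None and weight.grad is None
```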