Unverified Commit 78210127 authored by schetlur-nv's avatar schetlur-nv Committed by GitHub
Browse files

Conditional wgrad support (#21)



* Conditional dgrad/wgrad support
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixing the change to depend only on requires_grad. Also updating LayerNorm MLP
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Minor fixes.
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding conditional wgrad for LayerNormLinear
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* bug fix and remove conditional dgrad

Co-authored-by: schetlur-nv <schetlur@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adding unit test for wgrad disabled path
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding more unit tests for wgrad disabled path
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding unit tests for fp8 wgrad disabling, and cleaning up the code.
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>

* fix lint errors
Co-Authored-By: default avatarSharan Chetlur <schetlur@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarSharan Chetlur <schetlur@dlcluster.nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 924892fd
......@@ -86,8 +86,15 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
skip_wgrad = [True, False]
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
......@@ -103,6 +110,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
......@@ -113,7 +124,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -129,6 +140,10 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
......@@ -136,7 +151,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -152,6 +167,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
......@@ -161,10 +180,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
_disable_wgrads(block)
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
te_out = block(te_inp)
if isinstance(te_out, tuple):
......@@ -178,7 +201,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -194,14 +218,15 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -214,14 +239,15 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -239,14 +265,15 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config, fp8_recipe)
_test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -271,14 +298,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -303,14 +331,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -336,14 +365,15 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -366,14 +396,15 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -399,14 +430,15 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -432,4 +464,4 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
_test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
......@@ -37,14 +37,22 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
batch_sizes = [1, 2]
skip_wgrad = [True, False]
def _test_sanity_e2e_amp(block, bs, dtype, config):
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
return
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
).cuda()
te_inp_attn_mask = (
torch.rand(
(
......@@ -57,6 +65,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
......@@ -66,7 +78,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
torch.cuda.synchronize()
def _test_sanity_e2e(block, bs, dtype, config):
def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -82,13 +94,17 @@ def _test_sanity_e2e(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp_hidden_states, te_inp_attn_mask)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
def _test_sanity_e2e_T5(block, bs, dtype, config):
def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
te_inp_hidden_states = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
......@@ -104,6 +120,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
.cuda()
.bool()
)
if skip_wgrad:
_disable_wgrads(block)
te_out = block(
te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
)
......@@ -112,10 +132,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
torch.cuda.synchronize()
def _test_sanity_common(block, bs, dtype, config):
def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
te_inp = torch.randn(
config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
if skip_wgrad:
_disable_wgrads(block)
te_out = block(te_inp)
if isinstance(te_out, tuple):
te_out = te_out[0]
......@@ -127,7 +151,8 @@ def _test_sanity_common(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_linear(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -143,13 +168,14 @@ def test_sanity_layernorm_linear(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_linear(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_linear(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -162,13 +188,15 @@ def test_sanity_linear(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_layernorm_mlp(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -186,13 +214,14 @@ def test_sanity_layernorm_mlp(dtype, bs, model):
.to(dtype=dtype)
.cuda()
)
_test_sanity_common(block, bs, dtype, config)
_test_sanity_common(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_gpt(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_gpt(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -217,13 +246,14 @@ def test_sanity_gpt(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_bert(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_bert(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -248,13 +278,14 @@ def test_sanity_bert(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_T5(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_T5(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -280,13 +311,14 @@ def test_sanity_T5(dtype, bs, model):
.cuda()
)
_test_sanity_e2e_T5(block, bs, dtype, config)
_test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_amp_and_nvfuser(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -309,13 +341,14 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model):
.cuda()
)
_test_sanity_e2e_amp(block, bs, dtype, config)
_test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_drop_path(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -341,13 +374,14 @@ def test_sanity_drop_path(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_sanity_fused_qkv_params(dtype, bs, model):
@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
config = model_configs[model]
sigma = 0.023
......@@ -373,4 +407,4 @@ def test_sanity_fused_qkv_params(dtype, bs, model):
.cuda()
)
_test_sanity_e2e(block, bs, dtype, config)
_test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment