Unverified commit 78210127 authored by schetlur-nv, committed by GitHub

Conditional wgrad support (#21)



* Conditional dgrad/wgrad support
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fixing the change to depend only on requires_grad. Also updating LayerNormMLP
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Minor fixes.
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding conditional wgrad for LayerNormLinear
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Bug fix and removal of conditional dgrad
Co-authored-by: schetlur-nv <schetlur@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adding unit test for the wgrad-disabled path
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding more unit tests for the wgrad-disabled path
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Adding unit tests for FP8 wgrad disabling, and cleaning up the code
Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>

* Fix lint errors
Co-authored-by: Sharan Chetlur <schetlur@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Signed-off-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sharan Chetlur <schetlur@dlcluster.nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 924892fd
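The diffs below parametrize the existing sanity tests over a new skip_wgrad axis and add a _disable_wgrads helper that freezes every parameter of the block under test, exercising the conditional-wgrad path in which the weight-gradient GEMM is meant to be skipped whenever the parameters have requires_grad=False. As a rough illustration of the behavior this targets (a minimal sketch, not part of the diff; it assumes transformer_engine.pytorch.Linear with the usual (in_features, out_features) constructor, and the shapes are arbitrary):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()

# Freeze the parameters: with conditional wgrad, the backward pass is meant to skip
# the wgrad GEMM for frozen parameters instead of computing a gradient and discarding it.
for p in layer.parameters():
    p.requires_grad = False

inp = torch.randn(128, 1024, device="cuda", requires_grad=True)
out = layer(inp)
out.sum().backward()

assert inp.grad is not None                             # dgrad for the input is still produced
assert all(p.grad is None for p in layer.parameters())  # no wgrad is accumulated

The same freezing pattern is what _disable_wgrads applies inside the tests below.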
@@ -86,8 +86,15 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
 batch_sizes = [1, 2]
 
+skip_wgrad = [True, False]
+
+
+def _disable_wgrads(block):
+    for p in block.parameters():
+        p.requires_grad = False
+
 
-def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
@@ -103,6 +110,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
         with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, te_inp_attn_mask)
@@ -113,7 +124,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()
 
 
-def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -129,6 +140,10 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
     loss = te_out.sum()
@@ -136,7 +151,7 @@ def _test_sanity_e2e(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()
 
 
-def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -152,6 +167,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
@@ -161,10 +180,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe):
     torch.cuda.synchronize()
 
 
-def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
+def _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad):
     te_inp = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         te_out = block(te_inp)
     if isinstance(te_out, tuple):
@@ -178,7 +201,8 @@ def _test_sanity_common(block, bs, dtype, config, fp8_recipe):
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -194,14 +218,15 @@ def test_sanity_layernorm_linear(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_linear(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_linear(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -214,14 +239,15 @@ def test_sanity_linear(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -239,14 +265,15 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_common(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_gpt(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -271,14 +298,15 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_bert(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_bert(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -303,14 +331,15 @@ def test_sanity_bert(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_T5(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_T5(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -336,14 +365,15 @@ def test_sanity_T5(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e_T5(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -366,14 +396,15 @@ def test_sanity_amp_and_nvfuser(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e_amp(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_drop_path(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -399,14 +430,15 @@ def test_sanity_drop_path(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -432,4 +464,4 @@ def test_sanity_fused_qkv_params(dtype, bs, fp8_recipe, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe)
+    _test_sanity_e2e(block, bs, dtype, config, fp8_recipe, skip_wgrad)
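Every sanity test above gains the new skip_wgrad axis through the same pytest pattern; the module-level list and the test argument deliberately share the name skip_wgrad. A standalone sketch of that pattern (hypothetical test name, not part of the diff):

import pytest

skip_wgrad = [True, False]  # module-level list of values for the new test axis


@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
def test_example(skip_wgrad):
    # pytest binds each value of the module-level list to the same-named argument,
    # doubling the test matrix when combined with the dtype/bs/model axes.
    assert skip_wgrad in (True, False)

The hunks that follow apply the same change to the second sanity-test file, whose helpers run the blocks without an FP8 recipe.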
@@ -37,14 +37,22 @@ param_types = [torch.float32, torch.bfloat16, torch.float16]
 batch_sizes = [1, 2]
 
+skip_wgrad = [True, False]
+
+
+def _disable_wgrads(block):
+    for p in block.parameters():
+        p.requires_grad = False
+
 
-def _test_sanity_e2e_amp(block, bs, dtype, config):
+def _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad):
     if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
         return
 
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=torch.float32, requires_grad=True
     ).cuda()
 
     te_inp_attn_mask = (
         torch.rand(
             (
@@ -57,6 +65,10 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
         te_out = block(te_inp_hidden_states, te_inp_attn_mask)
         loss = te_out.sum()
@@ -66,7 +78,7 @@ def _test_sanity_e2e_amp(block, bs, dtype, config):
     torch.cuda.synchronize()
 
 
-def _test_sanity_e2e(block, bs, dtype, config):
+def _test_sanity_e2e(block, bs, dtype, config, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -82,13 +94,17 @@ def _test_sanity_e2e(block, bs, dtype, config):
         .cuda()
         .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(te_inp_hidden_states, te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()
 
 
-def _test_sanity_e2e_T5(block, bs, dtype, config):
+def _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad):
     te_inp_hidden_states = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
@@ -104,6 +120,10 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
        .cuda()
        .bool()
     )
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(
         te_inp_hidden_states, te_inp_attn_mask, encoder_output=te_inp_hidden_states
     )
@@ -112,10 +132,14 @@ def _test_sanity_e2e_T5(block, bs, dtype, config):
     torch.cuda.synchronize()
 
 
-def _test_sanity_common(block, bs, dtype, config):
+def _test_sanity_common(block, bs, dtype, config, skip_wgrad):
     te_inp = torch.randn(
         config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
+
+    if skip_wgrad:
+        _disable_wgrads(block)
+
     te_out = block(te_inp)
     if isinstance(te_out, tuple):
         te_out = te_out[0]
@@ -127,7 +151,8 @@ def _test_sanity_common(block, bs, dtype, config):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_linear(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_linear(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -143,13 +168,14 @@ def test_sanity_layernorm_linear(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_linear(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_linear(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -162,13 +188,15 @@ def test_sanity_linear(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_layernorm_mlp(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_layernorm_mlp(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -186,13 +214,14 @@ def test_sanity_layernorm_mlp(dtype, bs, model):
         .to(dtype=dtype)
         .cuda()
     )
-    _test_sanity_common(block, bs, dtype, config)
+    _test_sanity_common(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_gpt(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_gpt(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -217,13 +246,14 @@ def test_sanity_gpt(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_bert(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_bert(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -248,13 +278,14 @@ def test_sanity_bert(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_T5(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_T5(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -280,13 +311,14 @@ def test_sanity_T5(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e_T5(block, bs, dtype, config)
+    _test_sanity_e2e_T5(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_amp_and_nvfuser(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_amp_and_nvfuser(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -309,13 +341,14 @@ def test_sanity_amp_and_nvfuser(dtype, bs, model):
         .cuda()
     )
-    _test_sanity_e2e_amp(block, bs, dtype, config)
+    _test_sanity_e2e_amp(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_drop_path(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_drop_path(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -341,13 +374,14 @@ def test_sanity_drop_path(dtype, bs, model):
        .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
 
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_sanity_fused_qkv_params(dtype, bs, model):
+@pytest.mark.parametrize("skip_wgrad", skip_wgrad)
+def test_sanity_fused_qkv_params(dtype, bs, model, skip_wgrad):
     config = model_configs[model]
 
     sigma = 0.023
@@ -373,4 +407,4 @@ def test_sanity_fused_qkv_params(dtype, bs, model):
        .cuda()
     )
-    _test_sanity_e2e(block, bs, dtype, config)
+    _test_sanity_e2e(block, bs, dtype, config, skip_wgrad)
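The sanity tests above only check that forward and backward complete without error when every parameter of the block is frozen. A stricter follow-up assertion (a sketch with a hypothetical helper name, not part of this PR) could also verify that no weight gradients were accumulated while the input gradient is still produced:

def _assert_wgrads_skipped(block, inp):
    # After loss.backward() with all parameters frozen via _disable_wgrads,
    # no parameter should have received a gradient (wgrad was skipped) ...
    for name, p in block.named_parameters():
        assert not p.requires_grad
        assert p.grad is None, f"unexpected wgrad for {name}"
    # ... while the activation gradient (dgrad) for the input is still computed,
    # provided the input tensor was created with requires_grad=True.
    assert inp.grad is not None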