Commit 7b284fef (unverified), authored by Paweł Gadziński, committed by GitHub
Browse files

[Pytorch] Check gradient in test numerics (#1229)



* update test numerics
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update test numerics
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update test numerics
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update tests/pytorch/test_numerics.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* tests fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Not passing CI fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Not passing CI fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Fix key
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7a5fd0c9
...@@ -64,6 +64,7 @@ class ModelConfig: ...@@ -64,6 +64,7 @@ class ModelConfig:
model_configs = { model_configs = {
"small": ModelConfig(128, 1e-5, 8, 36, 4, 128),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048),
} }
...@@ -110,23 +111,30 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: ...@@ -110,23 +111,30 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose( def assert_allclose(
l1: List[torch.Tensor], l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
l2: List[torch.Tensor],
atol: float,
) -> bool: ) -> bool:
"""Ensures two lists are equal.""" """Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs." assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)): for i, (t1, t2) in enumerate(zip(l1, l2)):
result = torch.allclose(t1, t2, atol=atol) tols = dict(atol=atol)
if rtol is not None:
tols["rtol"] = rtol
result = torch.allclose(t1, t2, **tols)
if not result: if not result:
diff = torch.abs(t1 - t2).flatten() diff = torch.abs(t1 - t2)
m = torch.argmax(diff) tol = atol + (rtol * torch.abs(t2))
msg = ( exceed_mask = diff > tol
f"Outputs not close enough in tensor at idx={i}. " if exceed_mask.any():
f"Location of the maximum difference: {m.item()} " indices = torch.nonzero(exceed_mask, as_tuple=True)
f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " max_diff = diff[exceed_mask].max()
f"(diff {diff[m].item()})." max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
) max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg) raise AssertionError(msg)
...@@ -526,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False ...@@ -526,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params):
...@@ -631,7 +639,7 @@ def _test_e2e_full_recompute( ...@@ -631,7 +639,7 @@ def _test_e2e_full_recompute(
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_reentrant", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean)
...@@ -764,7 +772,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ...@@ -764,7 +772,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model): def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
...@@ -809,7 +817,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): ...@@ -809,7 +817,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model] config = model_configs[model]
...@@ -868,11 +876,25 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): ...@@ -868,11 +876,25 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config) te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config)
torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config) torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config)
atol = {
torch.float32: 5e-3,
torch.half: 5e-2,
torch.bfloat16: 1e-1,
}
# Check output. # Check output.
if dtype == torch.float32: assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else: # Check gradients, only for small model
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) if model == "small":
atol[torch.float32] = 5e-2
rtol = {
torch.float32: 1e-2,
torch.half: 1e-2,
torch.bfloat16: 1e-2,
}
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
...@@ -906,7 +928,7 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): ...@@ -906,7 +928,7 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("mask_type", mask_types) @pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type): def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model] config = model_configs[model]
...@@ -947,6 +969,21 @@ def test_mha_accuracy(dtype, bs, model, mask_type): ...@@ -947,6 +969,21 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
else: else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
# Check gradients, only for small model
if model == "small":
atol = {
torch.float32: 5e-2,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
rtol = {
torch.float32: 1e-2,
torch.half: 1e-2,
torch.bfloat16: 1e-2,
}
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config): def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states() reset_rng_states()
...@@ -1002,7 +1039,7 @@ def _test_dpa_accuracy(block, bs, dtype, config): ...@@ -1002,7 +1039,7 @@ def _test_dpa_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
def test_dpa_accuracy(dtype, bs, model): def test_dpa_accuracy(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
...@@ -1034,10 +1071,13 @@ def test_dpa_accuracy(dtype, bs, model): ...@@ -1034,10 +1071,13 @@ def test_dpa_accuracy(dtype, bs, model):
else: else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["small"])
def test_linear_accuracy(dtype, bs, model): def test_linear_accuracy(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
...@@ -1066,15 +1106,20 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -1066,15 +1106,20 @@ def test_linear_accuracy(dtype, bs, model):
torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
# Check output. # Check output.
if dtype == torch.float32: if model == "small":
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) tolerance = 5e-3 if dtype == torch.float32 else 5e-2
else: rtol = {
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) torch.float32: 1.3e-6,
torch.half: 1e-2,
torch.bfloat16: 2e-2,
}
for te_output, torch_output in zip(te_outputs, torch_outputs):
assert_allclose(te_output, torch_output, tolerance, rtol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
...@@ -1102,18 +1147,29 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1102,18 +1147,29 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config) te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output.
atol = { atol = {
torch.float32: 1e-7, torch.float32: 1e-7,
torch.half: 2e-3, torch.half: 2e-3,
torch.bfloat16: 2e-2, torch.bfloat16: 2e-2,
} }
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
atol[torch.float32] = 2e-3
rtol = {
torch.float32: 1.3e-6,
torch.half: 1e-3,
torch.bfloat16: 1.6e-2,
}
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
...@@ -1142,18 +1198,29 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): ...@@ -1142,18 +1198,29 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config) te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output.
atol = { atol = {
torch.float32: 1e-7, torch.float32: 1e-7,
torch.half: 2e-3, torch.half: 2e-3,
torch.bfloat16: 2e-2, torch.bfloat16: 2e-2,
} }
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
rtol = {
torch.float32: 1.3e-6,
torch.half: 1e-3,
torch.bfloat16: 1.6e-2,
}
atol[torch.float32] = 1e-4
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma): def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
...@@ -1195,18 +1262,34 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere ...@@ -1195,18 +1262,34 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config) te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
atol = { atol = {
torch.float32: 2.5e-4, torch.float32: 2.5e-4,
torch.half: 2e-3, torch.half: 2e-3,
torch.bfloat16: 2e-2, torch.bfloat16: 2e-2,
} }
# Check output.
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
if model == "small":
atol = {
torch.float32: 1e-3,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
rtol = {
torch.float32: 1e-3,
torch.half: 4e-2,
torch.bfloat16: 4e-2,
}
# Check gradients
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("normalization", all_normalizations)
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
...@@ -1246,11 +1329,26 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): ...@@ -1246,11 +1329,26 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config) te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
atol = {
torch.float32: 2e-2,
torch.half: 5e-2,
torch.bfloat16: 5e-2,
}
# Check output. # Check output.
if dtype == torch.float32: assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else: # Check gradients, only for small model
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) rtol = {
torch.float32: 1e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}
atol[torch.half] = 2e-1
atol[torch.bfloat16] = 2e-1
if model == "small":
for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False):
...@@ -1301,7 +1399,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False ...@@ -1301,7 +1399,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_grouped_linear_accuracy( def test_grouped_linear_accuracy(
...@@ -1361,7 +1459,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): ...@@ -1361,7 +1459,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
dtype=torch.float32, dtype=torch.float32,
num_gemms=6, num_gemms=6,
bs=2, bs=2,
model=list(model_configs.keys())[0], model="126m",
fp8=True, fp8=True,
fp8_model_params=True, fp8_model_params=True,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
...@@ -1374,7 +1472,7 @@ def test_grouped_linear_accuracy_single_gemm(): ...@@ -1374,7 +1472,7 @@ def test_grouped_linear_accuracy_single_gemm():
dtype=torch.float32, dtype=torch.float32,
num_gemms=1, num_gemms=1,
bs=2, bs=2,
model=list(model_configs.keys())[0], model="126m",
fp8=True, fp8=True,
fp8_model_params=True, fp8_model_params=True,
) )
...@@ -1475,7 +1573,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, f ...@@ -1475,7 +1573,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, f
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True]) @pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy( def test_padding_grouped_linear_accuracy(
...@@ -1594,7 +1692,7 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): ...@@ -1594,7 +1692,7 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model): def test_gpt_cuda_graph(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
...@@ -1686,7 +1784,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): ...@@ -1686,7 +1784,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
def test_gpt_fp8_parameters(dtype, bs, model): def test_gpt_fp8_parameters(dtype, bs, model):
if not fp8_available: if not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
...@@ -1710,7 +1808,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): ...@@ -1710,7 +1808,7 @@ def test_gpt_fp8_parameters(dtype, bs, model):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", ["126m"])
def test_transformer_layer_hidden_states_format(dtype, bs, model): def test_transformer_layer_hidden_states_format(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment