Commit 44740c6c authored by yuguo's avatar yuguo
Browse files

Merge commit '7a9a0825' of...

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -106,7 +106,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"] ...@@ -106,7 +106,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"] mask_types = ["causal", "no_mask"]
NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False) NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
if NVTE_TEST_NVINSPECT_ENABLED: if NVTE_TEST_NVINSPECT_ENABLED:
# The numerics of all the layers should work the same, # The numerics of all the layers should work the same,
...@@ -1059,8 +1059,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): ...@@ -1059,8 +1059,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False): def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
reset_rng_states() reset_rng_states()
fp8 = recipe is not None
if fp8:
FP8GlobalStateManager.reset()
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
(config.seq_len, bs, config.hidden_size), (config.seq_len, bs, config.hidden_size),
...@@ -1070,6 +1073,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False) ...@@ -1070,6 +1073,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False)
) )
inp_hidden_states.retain_grad() inp_hidden_states.retain_grad()
with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
out = block(inp_hidden_states) out = block(inp_hidden_states)
if isinstance(out, (List, Tuple)): if isinstance(out, (List, Tuple)):
out = out[0] out = out[0]
...@@ -1268,6 +1272,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ ...@@ -1268,6 +1272,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
torch.testing.assert_close(o, o_ref, rtol=0, atol=0) torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
def test_linear_accuracy_save_original_input(dtype, model, recipe):
    """Check that ``Linear(save_original_input=True)`` matches the default path exactly.

    Two ``Linear`` modules are built that differ only in ``save_original_input``.
    They are given identical weights and identical ``main_grad`` buffers, each is
    run through ``_test_granular_accuracy``, and the returned tensors are compared
    pairwise with zero tolerance.

    The test is skipped when the requested recipe is not available on this
    device, when the recipe is delayed scaling, or when FP8 is requested and the
    configured sequence length is not a multiple of 16.

    Args:
        dtype: Parameter dtype passed to both ``Linear`` modules.
        model: Key into ``model_configs``.
        recipe: FP8 recipe, or ``None`` to run without FP8.
    """
    bs = 1
    fuse_wgrad_accumulation = True
    fp8_model_params = False

    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    if fp8 and recipe.delayed():
        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
        pytest.skip("FP8 requires sequence length to be divisible by 16.")

    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
        # Reference module: default behaviour.
        te_linear_ref = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=False,
        ).eval()

        # Module under test: identical except for save_original_input.
        te_linear = Linear(
            config.hidden_size,
            4 * config.hidden_size,
            bias=False,
            params_dtype=dtype,
            device="cuda",
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            save_original_input=True,
        ).eval()

    # Share params
    with torch.no_grad():
        te_linear_ref.weight = Parameter(te_linear.weight.clone())
        if fuse_wgrad_accumulation:
            # Both modules start from the same random main_grad contents.
            weight = te_linear.weight
            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
            te_linear_ref.weight.main_grad = weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)

    # Should be a bit-wise match
    for o, o_ref in zip(te_outputs, te_outputs_ref):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
...@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy( ...@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy(
device="cuda", device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation, fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute, delay_wgrad_compute=delay_wgrad_compute,
save_original_input=False,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
for _ in range(num_gemms)
]
)
# Share params
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
outputs = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", [False])
@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
@pytest.mark.parametrize("bias", [False])
@pytest.mark.parametrize("delay_wgrad_compute", [True])
def test_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
parallel_mode=None,
):
fp8 = recipe is not None
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
save_original_input=True,
).eval() ).eval()
sequential_linear = torch.nn.ModuleList( sequential_linear = torch.nn.ModuleList(
[ [
...@@ -1948,7 +2115,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r ...@@ -1948,7 +2115,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
def test_padding_grouped_linear_accuracy( def test_padding_grouped_linear_accuracy(
dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
fp8=fp8,
).eval()
with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
ref_grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
save_original_input=False,
).eval()
# Share params
with torch.no_grad():
inner_grouped_linear = grouped_linear.linear_fn
for i in range(num_gemms):
setattr(
ref_grouped_linear,
f"weight{i}",
Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
)
outputs = _test_padding_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
outputs_ref = _test_padding_grouped_linear_accuracy(
ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("num_gemms", [3])
@pytest.mark.parametrize("bs", [1])
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8", [True])
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", [False])
def test_padding_grouped_linear_accuracy_save_original_input(
dtype,
num_gemms,
bs,
model,
fp8,
recipe,
fp8_model_params,
parallel_mode=None,
): ):
if fp8 and not fp8_available: if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
...@@ -1958,6 +2207,8 @@ def test_padding_grouped_linear_accuracy( ...@@ -1958,6 +2207,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip("FP8 parameters are not supported in debug mode.") pytest.skip("FP8 parameters are not supported in debug mode.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available: if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling) pytest.skip(reason_for_no_fp8_block_scaling)
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
config = model_configs[model] config = model_configs[model]
if config.seq_len % 16 != 0 and fp8: if config.seq_len % 16 != 0 and fp8:
...@@ -1983,6 +2234,7 @@ def test_padding_grouped_linear_accuracy( ...@@ -1983,6 +2234,7 @@ def test_padding_grouped_linear_accuracy(
params_dtype=dtype, params_dtype=dtype,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device="cuda", device="cuda",
save_original_input=True,
).eval() ).eval()
# Share params # Share params
...@@ -2334,9 +2586,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, ...@@ -2334,9 +2586,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
if ( if (
backend == "FusedAttention" backend == "FusedAttention"
and get_device_compute_capability() == (8, 9) and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 11, 0) and get_cudnn_version() < (9, 12, 0)
): ):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11") pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
......
This diff is collapsed.
...@@ -61,19 +61,23 @@ class TestParallelCrossEntropy: ...@@ -61,19 +61,23 @@ class TestParallelCrossEntropy:
test_loss = self.test_loss_func( test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None self.input_test, self.tar_test, label_smoothing, reduce_loss, None
) )
if reduce_loss:
test_loss.backward()
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
# Handle backward pass based on the test scenario
if reduce_loss: if reduce_loss:
test_loss.backward()
ref_loss.backward() ref_loss.backward()
else:
test_loss.sum().backward()
ref_loss.sum().backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx: if ignore_idx:
print(test_loss, ref_loss) print(test_loss, ref_loss)
if reduce_loss:
# Compare gradients when backward pass was called
torch.testing.assert_close( torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
) )
......
...@@ -326,6 +326,7 @@ def _test_permutation_index_map( ...@@ -326,6 +326,7 @@ def _test_permutation_index_map(
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close( torch.testing.assert_close(
pytorch_permute_output.float(), pytorch_permute_output.float(),
te_permute_output_, te_permute_output_,
...@@ -351,7 +352,10 @@ def _test_permutation_index_map( ...@@ -351,7 +352,10 @@ def _test_permutation_index_map(
) )
if with_probs: if with_probs:
torch.testing.assert_close( torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
) )
if not pytorch_permute_fwd_input.numel(): if not pytorch_permute_fwd_input.numel():
...@@ -538,6 +542,7 @@ def _test_permutation_mask_map( ...@@ -538,6 +542,7 @@ def _test_permutation_mask_map(
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close( torch.testing.assert_close(
pytorch_permute_output.float(), pytorch_permute_output.float(),
te_permute_output_, te_permute_output_,
...@@ -564,7 +569,10 @@ def _test_permutation_mask_map( ...@@ -564,7 +569,10 @@ def _test_permutation_mask_map(
) )
if with_probs: if with_probs:
torch.testing.assert_close( torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
) )
if not pytorch_permute_fwd_input.numel(): if not pytorch_permute_fwd_input.numel():
...@@ -827,6 +835,7 @@ def _test_moe_chunk_sort( ...@@ -827,6 +835,7 @@ def _test_moe_chunk_sort(
te_output_ = te_output.float() te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float() te_fwd_input_grad = te_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close( torch.testing.assert_close(
pytorch_output.float(), pytorch_output.float(),
te_output_, te_output_,
...@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
topK, topK,
num_out_tokens, num_out_tokens,
tp_size, tp_size,
BENCHMARK=False,
): ):
if topK > num_expert: if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.") pytest.skip("topK should be smaller than the number of experts.")
...@@ -1016,6 +1026,7 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -1016,6 +1026,7 @@ def _test_permutation_mask_map_alongside_probs(
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float() te_unpermute_output_ = te_unpermute_output.float()
if not BENCHMARK:
torch.testing.assert_close( torch.testing.assert_close(
pytorch_unpermute_output.float(), pytorch_unpermute_output.float(),
te_unpermute_output_, te_unpermute_output_,
...@@ -1032,6 +1043,57 @@ def _test_permutation_mask_map_alongside_probs( ...@@ -1032,6 +1043,57 @@ def _test_permutation_mask_map_alongside_probs(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
) )
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: te_permute_with_probs(
te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
)
)
print(f"permute\t\tfwd: TE: {t1:.3f} ms")
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
te_probs,
routing_map,
num_out_tokens=num_out_tokens,
)
te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: TE: {t2:.3f} ms")
chunk_sort_fwd_input = te_permute_output.detach()
chunk_sort_fwd_input.requires_grad_(True)
chunk_sort_fwd_probs = te_permuted_probs.detach()
chunk_sort_fwd_probs.requires_grad_(True)
t1 = perf_test_cuda_kernel(
lambda: te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
)
print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")
chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
chunk_sort_output,
te_permute_bwd_input,
forward_input=[chunk_sort_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn): def perf_test_cuda_kernel(cuda_kernel_fn):
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -1063,7 +1125,7 @@ if is_bf16_compatible(): ...@@ -1063,7 +1125,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1092,7 +1154,7 @@ def test_permutation_index_map( ...@@ -1092,7 +1154,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype): ...@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1193,7 +1255,7 @@ fp8_recipes = [ ...@@ -1193,7 +1255,7 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048]) @pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5]) @pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("num_out_tokens", [None, 2039])
...@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8( ...@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs( def test_permutation_index_map_topk1_no_probs(
te_dtype, te_dtype,
...@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs( ...@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs( def test_permutation_mask_map_topk1_no_probs(
te_dtype, te_dtype,
...@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs( ...@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes) @pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16]) @pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8]) @pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation( def test_chunk_permutation(
...@@ -1372,5 +1434,108 @@ def test_permutation_single_case(): ...@@ -1372,5 +1434,108 @@ def test_permutation_single_case():
) )
def benchmark_single_case(
    te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
    """Run the permutation helpers once in benchmark mode for one problem size.

    An outer NVTX range labelled with the problem dimensions wraps the whole
    call. Inside it, each helper runs in its own named NVTX range and is
    invoked with ``BENCHMARK=True``.

    ``ep_size`` only appears in the outer range label. ``tp_size`` appears in
    the label and is forwarded to ``_test_permutation_mask_map_alongside_probs``.
    """
    problem = {
        "te_dtype": te_dtype,
        "num_tokens": num_tokens,
        "num_expert": num_expert,
        "hidden_size": hidden_size,
        "topK": topK,
        "num_out_tokens": num_out_tokens,
    }

    torch.cuda.nvtx.range_push(
        f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
    )

    # (NVTX range name, helper to run, helper-specific keyword arguments)
    stages = (
        ("permutation_index_map_with_probs", _test_permutation_index_map, {"with_probs": True}),
        ("permutation_mask_map_with_probs", _test_permutation_mask_map, {"with_probs": True}),
        ("permutation_mask_map_without_probs", _test_permutation_mask_map, {"with_probs": False}),
        (
            "permutation_mask_map_alongside_probs",
            _test_permutation_mask_map_alongside_probs,
            {"tp_size": tp_size},
        ),
    )
    for range_name, helper, extra_kwargs in stages:
        torch.cuda.nvtx.range_push(range_name)
        helper(**problem, **extra_kwargs, BENCHMARK=True)
        torch.cuda.nvtx.range_pop()

    # Close the outer per-problem range.
    torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases():
    """Print the GPU name, then benchmark three hard-coded problem sizes in BF16.

    Each problem is passed to ``benchmark_single_case`` with
    ``num_out_tokens`` set to ``num_tokens * topK``.
    """
    print("GPU:", torch.cuda.get_device_name(0))

    # Alternatives that can be swapped in here:
    # te_dtype = tex.DType.kFloat32
    # te_dtype = tex.DType.kFloat16
    te_dtype = tex.DType.kBFloat16

    # Columns: ep_size, tp_size, num_tokens, num_expert, hidden_size, topK
    problems = (
        (64, 2, 4096, 256, 7168, 8),
        (8, 1, 8192 * 2, 128, 4096, 6),
        (64, 2, 16384, 4, 7168, 1),
    )
    for ep_size, tp_size, num_tokens, num_expert, hidden_size, topK in problems:
        num_out_tokens = num_tokens * topK
        benchmark_single_case(
            te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
        )
if __name__ == "__main__": if __name__ == "__main__":
test_permutation_single_case() benchmark_multiple_cases()
...@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols from utils import dtype_tols
# Only run FP8 tests on supported devices. # Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( ...@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
) )
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state from script run.
# Both the CPU and the CUDA generators are seeded with the same fixed value,
# then their states are captured so reset_rng_states() can restore them.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
# Debug mode is opt-in through the environment; the value is parsed as an
# integer, so "0" (the default) leaves it off.
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))

if NVTE_TEST_NVINSPECT_ENABLED:
    # The sanity tests are expected to behave the same with debug=True.
    # A dummy feature is supplied so that debug mode is not switched off,
    # which can happen when no feature is active.
    import nvdlfw_inspect.api as debug_api

    debug_api.initialize(
        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
    )
def create_meta(scale_factor: float, size: int = 1): def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta() meta = tex.FP8TensorMeta()
...@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: ...@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
return torch.min(amax_history, dim=0).values return torch.min(amax_history, dim=0).values
def reset_rng_states() -> None:
    """Restore the CPU and CUDA RNG states saved in the module-level snapshots."""
    # The snapshots are only read here, so no ``global`` declaration is needed.
    cpu_snapshot, cuda_snapshot = _cpu_rng_state, _cuda_rng_state
    torch.set_rng_state(cpu_snapshot)
    torch.cuda.set_rng_state(cuda_snapshot)
@dataclass @dataclass
class ModelConfig: class ModelConfig:
"""Transformer model configuration""" """Transformer model configuration"""
...@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba ...@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean)
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("Quantized model parameters are not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
num_tokens = bs * config.seq_len num_tokens = bs * config.seq_len
...@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
def test_sanity_grouped_linear( def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
): ):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size ffn_hidden_size = 4 * config.hidden_size
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
...@@ -682,6 +715,8 @@ def test_sanity_gpt( ...@@ -682,6 +715,8 @@ def test_sanity_gpt(
parallel_attention_mlp, parallel_attention_mlp,
cpu_offload, cpu_offload,
): ):
if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("CPU offload is not supported in debug mode.")
config = model_configs[model] config = model_configs[model]
if fp8_recipe is not None: if fp8_recipe is not None:
...@@ -1367,6 +1402,8 @@ def test_inference_mode( ...@@ -1367,6 +1402,8 @@ def test_inference_mode(
quantization: Optional[str], quantization: Optional[str],
) -> None: ) -> None:
"""Test heuristics for initializing quantized weights""" """Test heuristics for initializing quantized weights"""
if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
pytest.skip("Quantized model parameters are not supported in debug mode.")
# Tensor dimensions # Tensor dimensions
sequence_length = 32 sequence_length = 32
......
...@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ...@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
if name in ("fp8", "fp8_delayed_scaling"): if name in ("fp8", "fp8_delayed_scaling"):
return transformer_engine.common.recipe.DelayedScaling( return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3, fp8_format=transformer_engine.common.recipe.Format.E4M3,
amax_history_len=8,
) )
if name == "fp8_current_scaling": if name == "fp8_current_scaling":
return transformer_engine.common.recipe.Float8CurrentScaling( return transformer_engine.common.recipe.Float8CurrentScaling(
......
...@@ -158,6 +158,9 @@ if(USE_CUDA) ...@@ -158,6 +158,9 @@ if(USE_CUDA)
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu recipe/current_scaling.cu
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu recipe/fp8_block_scaling.cu
...@@ -211,6 +214,9 @@ else() ...@@ -211,6 +214,9 @@ else()
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu fused_rope/fused_rope.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
recipe/current_scaling.cu recipe/current_scaling.cu
recipe/delayed_scaling.cu recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu recipe/fp8_block_scaling.cu
......
...@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) { ...@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) {
} }
return false; return false;
#else #else
// Check run-time CUDA version
if (transformer_engine::cuda::cudart_version() < 12040) {
if (getenv("NVTE_UBDEBUG")) {
printf(
"TransformerEngine does not support multi-node NVLINK "
"since it is not being run with CUDA version >= 12.4.\n");
}
return false;
}
bool mnnvl_fabric_support = false; bool mnnvl_fabric_support = false;
CUdevice dev; CUdevice dev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
......
This diff is collapsed.
This diff is collapsed.
...@@ -52,7 +52,7 @@ class OptionalCUDAGuard { ...@@ -52,7 +52,7 @@ class OptionalCUDAGuard {
~OptionalCUDAGuard() { ~OptionalCUDAGuard() {
if (device_changed_) { if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_)); cudaSetDevice(prev_device_);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment