Commit 0d874a4e authored by wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -363,6 +363,28 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState._reset()
# Check that statistics logging covers every GEMM of a GroupedLinear layer:
# after one forward/backward pass and one debug step, the log text must
# mention "gemm_0", "gemm_1" and "gemm_2".
# NOTE(review): indentation was lost in this listing, so the exact extent of
# the two `with` blocks below cannot be read off here -- confirm against the
# original file before relying on which statements run inside each context.
def test_log_grouped_gemm(feature_dirs):
# FP8 support is required; otherwise the test is skipped with the stored reason.
if not fp8_available:
pytest.skip(reason_for_no_fp8)
# Build the debug config by substituting the comma-joined list of all stats
# into the quantized-logging config template.
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
# debug_session yields the directory that read_log() is later pointed at.
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
# presumably the first argument (3) is the number of grouped GEMMs, matching
# the three m_splits entries and the three asserts below -- TODO confirm.
model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16)
inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda()
# Three split sizes that sum to 128.
m_splits = [64, 32, 32]
# Forward pass runs under FP8 autocast with the DelayedScaling recipe.
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp, m_splits=m_splits)
# Reduce to a scalar so backward() can be called.
loss = output.sum()
loss.backward()
# presumably advances the debug iteration so collected stats are written
# out -- TODO confirm against the debug API.
debug_api.step()
# `output` is rebound here: from this point on it holds the log contents,
# not the model output.
output = read_log(log_dir)
# The searched keys use an underscore ("gemm_0") while the failure messages
# spell the name without one ("gemm0").
assert "gemm_0" in output, "gemm0 not found in output"
assert "gemm_1" in output, "gemm1 not found in output"
assert "gemm_2" in output, "gemm2 not found in output"
def test_compute_max_blockwise_dynamic_range_direct():
"""Direct unit test for compute_max_blockwise_dynamic_range function.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -82,7 +82,6 @@ def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split
out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1,
fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32,
use_split_accumulator=use_split_accumulator,
)
......@@ -199,7 +198,6 @@ def _emulate_linear(
wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input,
wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32,
layout="NT",
grad=True,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -25,10 +25,8 @@ from transformer_engine.pytorch import (
MXFP8Quantizer,
)
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.module.base import (
fill_userbuffers_buffer_for_all_gather,
get_cublas_workspace_size_bytes,
)
from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
......@@ -420,10 +418,6 @@ def _main(opts):
std=opts.std,
)
# Allocate cuBLAS workspace
workspace_size = 1 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1)
......@@ -620,7 +614,6 @@ def _main(opts):
return tex.general_gemm(
kernel_t_fp8,
gemm_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
......@@ -638,7 +631,6 @@ def _main(opts):
return tex.general_gemm(
kernel2_t_fp8,
gemm2_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out2_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
......@@ -651,7 +643,6 @@ def _main(opts):
return tex.general_gemm(
kernel_t,
gemm_inp,
workspace,
out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj,
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -39,8 +39,9 @@ WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED") or "0")
if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
if NVTE_TEST_NVINSPECT_ENABLED:
# The numerics of all the layers should work the same,
# when debug=True. I fed them with dummy feature
# to prevent switching off debug, which can happen if
......@@ -754,6 +755,8 @@ def test_linear():
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
continue
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
......@@ -941,6 +944,8 @@ def test_layernorm_linear():
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
continue
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
_test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
......@@ -1047,6 +1052,7 @@ def test_layernorm_mlp():
{"return_bias": True},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
{"checkpoint": True},
]
# TODO: The blockwise recipe does not currently support calculations with bias set to True.
"""
......@@ -1058,6 +1064,8 @@ def test_layernorm_mlp():
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
if kwargs.get("delay_wgrad_compute", False) and NVTE_TEST_NVINSPECT_ENABLED:
continue
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
_test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)
......
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -20,6 +20,7 @@ from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
MXFP8BlockScaling,
Format,
Recipe,
)
......@@ -27,9 +28,11 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch import (
is_fp8_available,
is_fp8_block_scaling_available,
is_mxfp8_available,
QuantizedTensor,
Float8Tensor,
Float8BlockwiseQTensor,
MXFP8Tensor,
)
from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8
from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data
......@@ -44,17 +47,21 @@ def _get_quantization_recipe(quantization) -> Recipe:
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
elif quantization == "mxfp8":
return MXFP8BlockScaling()
else:
raise ValueError(f"Unsupported quantization: {quantization}")
def _get_raw_data(quantized_tensor):
def _get_raw_data(quantized_tensor, colwise=False):
"""Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
if isinstance(quantized_tensor, Float8Tensor):
assert not colwise, "Float8Tensor does not support get colwise data"
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert not colwise, "Float8BlockwiseQTensor does not support get colwise data"
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
......@@ -62,6 +69,23 @@ def _get_raw_data(quantized_tensor):
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
elif isinstance(quantized_tensor, MXFP8Tensor):
if colwise:
assert hasattr(
quantized_tensor, "_columnwise_data"
), "MXFP8Tensor does not have columnwise_data attribute"
assert (
quantized_tensor._columnwise_data.dtype == torch.uint8
), "MXFP8Tensor columnwise_data must be uint8"
return quantized_tensor._columnwise_data
else:
assert hasattr(
quantized_tensor, "_rowwise_data"
), "MXFP8Tensor does not have rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "MXFP8Tensor rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
......@@ -231,38 +255,43 @@ class MiniZero_1:
end = start_offset + master_weight.numel()
weight.data.view(-1)[start:end].copy_(master_weight)
# -----------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -----------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i])
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
colwise_list = [False]
if isinstance(self.weights[0], MXFP8Tensor):
colwise_list.append(True)
# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
for colwise in colwise_list:
# -------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i], colwise)
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
# -------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
# -------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight, colwise)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
if self.manual_post_all_gather_processing:
quantized_weights = [
......@@ -287,9 +316,15 @@ class MiniFSDP:
else:
raw_data_list = [w.view(-1) for w in weights]
self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
if isinstance(weights[0], MXFP8Tensor):
self.flatten_columnwise = self.flatten_weight.clone()
else:
self.flatten_columnwise = None
# Split flattened weights into shards
self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
if self.flatten_columnwise is not None:
self.local_columnwise_shard = torch.chunk(self.flatten_columnwise, world_size)[rank]
self.local_main_grad_shard = torch.zeros_like(
self.local_weight_shard, dtype=torch.float32, device="cuda"
)
......@@ -321,14 +356,25 @@ class MiniFSDP:
self.shard_indices.append((None, None))
if isinstance(weights[idx], QuantizedTensor):
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
if self.flatten_columnwise is not None:
new_rowwise_data = self.flatten_weight[start:end].view(weights[idx].shape)
new_rowwise_data.copy_(weights[idx]._rowwise_data)
weights[idx]._rowwise_data = new_rowwise_data
new_columnwise_data = self.flatten_columnwise[start:end].view(
weights[idx].shape
)
new_columnwise_data.copy_(weights[idx]._columnwise_data)
weights[idx]._columnwise_data = new_columnwise_data
else:
replace_raw_data(
weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
)
else:
weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
# Initialize local model weights and high-precision master weights
self.local_weights = []
self.local_columnwise = []
self.master_weights = []
for i, weight in enumerate(self.weights):
weight_start, weight_end = self.weight_indices[i]
......@@ -336,6 +382,11 @@ class MiniFSDP:
if shard_start is not None and shard_end is not None:
local_weight_shard = self.local_weight_shard[shard_start:shard_end]
self.local_weights.append(local_weight_shard)
if self.flatten_columnwise is not None:
local_columnwise_shard = self.local_columnwise_shard[shard_start:shard_end]
else:
local_columnwise_shard = None
self.local_columnwise.append(local_columnwise_shard)
if isinstance(weight, QuantizedTensor):
high_precision_init_val = weight.get_high_precision_init_val().view(-1)
......@@ -347,6 +398,7 @@ class MiniFSDP:
self.master_weights.append(master_weight_shard)
else:
self.local_weights.append(None)
self.local_columnwise.append(None)
self.master_weights.append(None)
setattr(
weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
......@@ -417,12 +469,12 @@ class MiniFSDP:
# Step 3: Cast master weights to FP8 or BF16 precision
if isinstance(self.weights[0], QuantizedTensor):
local_weights = []
for local_weight in self.local_weights:
if local_weight is None:
local_weights.append(None)
continue
local_weights.append(local_weight)
for i, local_weight in enumerate(self.local_weights):
if self.flatten_columnwise is not None:
local_columnwise = self.local_columnwise[i]
local_weights.append((local_weight, local_columnwise))
else:
local_weights.append(local_weight)
cast_master_weights_to_fp8(
self.weights,
......@@ -444,6 +496,10 @@ class MiniFSDP:
dist.all_gather_into_tensor(
self.flatten_weight, self.local_weight_shard, group=self.dp_group
)
if self.flatten_columnwise is not None:
dist.all_gather_into_tensor(
self.flatten_columnwise, self.local_columnwise_shard, group=self.dp_group
)
if self.manual_post_all_gather_processing:
quantized_weights = [
......@@ -515,15 +571,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -548,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
w.main_grad.zero_()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
......@@ -579,7 +635,9 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
assert torch.allclose(
loss_fp8, loss, atol=0, rtol=0
), f"Loss mismatch at rank {rank}, step {i} for {quantization}"
def _test_fsdp_cast_master_weights_to_fp8(
......@@ -611,15 +669,15 @@ def _test_fsdp_cast_master_weights_to_fp8(
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 32, **linear_kwargs),
te.Linear(256 + 32, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -633,12 +691,12 @@ def _test_fsdp_cast_master_weights_to_fp8(
)
optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
for i in range(100):
optimizer_fp8.zero_grad()
optimizer.zero_grad()
inputs = [
torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
torch.randn(32, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
]
# Choose based on rank to make sure the inputs of different ranks are different.
x = inputs[rank]
......@@ -669,7 +727,9 @@ def _test_fsdp_cast_master_weights_to_fp8(
optimizer_fp8.step()
optimizer.step()
torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
assert torch.allclose(
loss_fp8, loss, atol=0, rtol=0
), f"Loss mismatch at rank {rank}, step {i} for {quantization} (FSDP)"
def run_parallel_tests() -> None:
......@@ -700,6 +760,8 @@ def run_parallel_tests() -> None:
quantizations.extend(["fp8", "fp8_cs"])
if is_fp8_block_scaling_available():
quantizations.append("fp8_block")
if is_mxfp8_available():
quantizations.append("mxfp8")
manual_post_all_gather_processings = [False, True]
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=1024 --batch-size=2 --num-heads=48 --head-dim=64 --comm-type=AG --p2p
......@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
if te.get_device_compute_capability() <= (8, 0):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os.environ["NVTE_FLASH_ATTN"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
os.unsetenv("PYTORCH_JIT")
os.unsetenv("NVTE_TORCH_COMPILE")
os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
os.unsetenv("NVTE_FLASH_ATTN")
if (
result.returncode != 0
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -13,7 +13,7 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
"""
Distributed numerics tests
These tests test the numerical corectness of the TransformerEngine layers.
These tests test the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and fp8 precision.
One test consists of running multiple configurations from file run_numerics.py
Such design is due to the fact the initialization of one test is long
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment