Commit 44740c6c authored by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
@@ -106,7 +106,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
 mask_types = ["causal", "no_mask"]

-NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)
+NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
 if NVTE_TEST_NVINSPECT_ENABLED:
     # The numerics of all the layers should work the same,
@@ -1059,8 +1059,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
     assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])

-def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False):
+def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, recipe=None):
     reset_rng_states()
+    fp8 = recipe is not None
+    if fp8:
+        FP8GlobalStateManager.reset()

     inp_hidden_states = torch.randn(
         (config.seq_len, bs, config.hidden_size),
@@ -1070,9 +1073,10 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False)
     )
     inp_hidden_states.retain_grad()

-    out = block(inp_hidden_states)
-    if isinstance(out, (List, Tuple)):
-        out = out[0]
+    with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+        out = block(inp_hidden_states)
+        if isinstance(out, (List, Tuple)):
+            out = out[0]

     loss = out.sum()
     loss.backward()
     if delay_wgrad_compute:
@@ -1268,6 +1272,64 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model", ["small"])
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
+def test_linear_accuracy_save_original_input(dtype, model, recipe):
+    bs = 1
+    fuse_wgrad_accumulation = True
+    fp8_model_params = False
+    fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and recipe.delayed():
+        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+
+    config = model_configs[model]
+    if config.seq_len % 16 != 0 and fp8:
+        pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+        te_linear_ref = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+            save_original_input=False,
+        ).eval()
+        te_linear = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+            save_original_input=True,
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        te_linear_ref.weight = Parameter(te_linear.weight.clone())
+        if fuse_wgrad_accumulation:
+            weight = getattr(te_linear, f"weight")
+            weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
+            te_linear_ref.weight.main_grad = weight.main_grad.clone()
+
+    te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, recipe=recipe)
+    te_outputs_ref = _test_granular_accuracy(te_linear_ref, bs, dtype, config, recipe=recipe)
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["126m"])
@@ -1768,6 +1830,111 @@ def test_grouped_linear_accuracy(
             device="cuda",
             fuse_wgrad_accumulation=fuse_wgrad_accumulation,
             delay_wgrad_compute=delay_wgrad_compute,
+            save_original_input=False,
+        ).eval()
+        sequential_linear = torch.nn.ModuleList(
+            [
+                Linear(
+                    config.hidden_size,
+                    4 * config.hidden_size,
+                    bias=bias,
+                    params_dtype=dtype,
+                    parallel_mode=parallel_mode,
+                    device="cuda",
+                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+                ).eval()
+                for _ in range(num_gemms)
+            ]
+        )
+
+    # Share params
+    with torch.no_grad():
+        for i in range(num_gemms):
+            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            if bias:
+                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+            if fuse_wgrad_accumulation:
+                weight_i = getattr(grouped_linear, f"weight{i}")
+                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+
+    outputs_ref = _test_grouped_linear_accuracy(
+        sequential_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8,
+        fuse_wgrad_accumulation,
+        delay_wgrad_compute,
+    )
+    outputs = _test_grouped_linear_accuracy(
+        grouped_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8,
+        fuse_wgrad_accumulation,
+        delay_wgrad_compute,
+    )
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_gemms", [3])
+@pytest.mark.parametrize("bs", [1])
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_model_params", [False])
+@pytest.mark.parametrize("fuse_wgrad_accumulation", [True])
+@pytest.mark.parametrize("bias", [False])
+@pytest.mark.parametrize("delay_wgrad_compute", [True])
+def test_grouped_linear_accuracy_save_original_input(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    recipe,
+    fp8_model_params,
+    fuse_wgrad_accumulation,
+    bias,
+    delay_wgrad_compute,
+    parallel_mode=None,
+):
+    fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and recipe.delayed():
+        pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+
+    config = model_configs[model]
+    if config.seq_len % 16 != 0 and fp8:
+        pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+        grouped_linear = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=bias,
+            params_dtype=dtype,
+            parallel_mode=parallel_mode,
+            device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+            delay_wgrad_compute=delay_wgrad_compute,
+            save_original_input=True,
+        ).eval()
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
@@ -1948,7 +2115,89 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 def test_padding_grouped_linear_accuracy(
-    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fp8,
+    recipe,
+    fp8_model_params,
+    parallel_mode=None,
+):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+
+    config = model_configs[model]
+    if config.seq_len % 16 != 0 and fp8:
+        pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+        grouped_linear = TorchGroupedLinearWithPadding(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            parallel_mode=parallel_mode,
+            fp8=fp8,
+        ).eval()
+
+    with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+        ref_grouped_linear = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            parallel_mode=parallel_mode,
+            device="cuda",
+            save_original_input=False,
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        inner_grouped_linear = grouped_linear.linear_fn
+        for i in range(num_gemms):
+            setattr(
+                ref_grouped_linear,
+                f"weight{i}",
+                Parameter(getattr(inner_grouped_linear, f"weight{i}").clone()),
+            )
+
+    outputs = _test_padding_grouped_linear_accuracy(
+        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
+    )
+    outputs_ref = _test_padding_grouped_linear_accuracy(
+        ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
+    )
+
+    # Should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("num_gemms", [3])
+@pytest.mark.parametrize("bs", [1])
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("fp8", [True])
+@pytest.mark.parametrize("recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_model_params", [False])
+def test_padding_grouped_linear_accuracy_save_original_input(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fp8,
+    recipe,
+    fp8_model_params,
+    parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1958,6 +2207,8 @@ def test_padding_grouped_linear_accuracy(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and recipe.delayed():
+        pytest.skip("DelayedScaling recipe is not supported with save_original_input")

     config = model_configs[model]
     if config.seq_len % 16 != 0 and fp8:
@@ -1983,6 +2234,7 @@ def test_padding_grouped_linear_accuracy(
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
+            save_original_input=True,
         ).eval()

     # Share params
@@ -2334,9 +2586,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     if (
         backend == "FusedAttention"
         and get_device_compute_capability() == (8, 9)
-        and get_cudnn_version() < (9, 11, 0)
+        and get_cudnn_version() < (9, 12, 0)
     ):
-        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
......
@@ -61,22 +61,26 @@ class TestParallelCrossEntropy:
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-        if reduce_loss:
-            test_loss.backward()

         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)

+        # Handle backward pass based on the test scenario
         if reduce_loss:
+            test_loss.backward()
             ref_loss.backward()
+        else:
+            test_loss.sum().backward()
+            ref_loss.sum().backward()

         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
+        torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)

         if ignore_idx:
             print(test_loss, ref_loss)

-        if reduce_loss:
-            torch.testing.assert_close(
-                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-            )
+        # Compare gradients when backward pass was called
+        torch.testing.assert_close(
+            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
+        )

         self.input_test = None
         self.input_ref = None
......
@@ -47,7 +47,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from transformer_engine.pytorch.distributed import checkpoint
-from test_numerics import reset_rng_states, dtype_tols
+from utils import dtype_tols

 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -56,6 +56,28 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
 )
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+# Record initial RNG state from script run.
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+_cpu_rng_state = torch.get_rng_state()
+_cuda_rng_state = torch.cuda.get_rng_state()
+
+NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0"))
+
+if NVTE_TEST_NVINSPECT_ENABLED:
+    # The sanity tests should work the same when debug=True.
+    # They are fed a dummy feature to prevent debug from being
+    # switched off, which can happen if no feature is active.
+    import nvdlfw_inspect.api as debug_api
+
+    debug_api.initialize(
+        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
+        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
+    )
 def create_meta(scale_factor: float, size: int = 1):
     meta = tex.FP8TensorMeta()
@@ -90,6 +112,13 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor:
     return torch.min(amax_history, dim=0).values

+def reset_rng_states() -> None:
+    """revert back to initial RNG state."""
+    global _cpu_rng_state, _cuda_rng_state
+    torch.set_rng_state(_cpu_rng_state)
+    torch.cuda.set_rng_state(_cuda_rng_state)
+
 @dataclass
 class ModelConfig:
     """Transformer model configuration"""
@@ -529,6 +558,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
 def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
+    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
+        pytest.skip("Quantized model parameters are not supported in debug mode.")
     config = model_configs[model]
     ffn_hidden_size = 4 * config.hidden_size
     num_tokens = bs * config.seq_len
@@ -570,6 +601,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 def test_sanity_grouped_linear(
     dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
 ):
+    if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
+        pytest.skip("FP8 model parameters are not supported in debug mode.")
     config = model_configs[model]
     ffn_hidden_size = 4 * config.hidden_size
     # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
@@ -682,6 +715,8 @@ def test_sanity_gpt(
     parallel_attention_mlp,
     cpu_offload,
 ):
+    if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED:
+        pytest.skip("CPU offload is not supported in debug mode.")
     config = model_configs[model]
     if fp8_recipe is not None:
@@ -1367,6 +1402,8 @@ def test_inference_mode(
     quantization: Optional[str],
 ) -> None:
     """Test heuristics for initializing quantized weights"""
+    if NVTE_TEST_NVINSPECT_ENABLED and quantization is not None:
+        pytest.skip("Quantized model parameters are not supported in debug mode.")
     # Tensor dimensions
     sequence_length = 32
......
@@ -93,6 +93,7 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
     if name in ("fp8", "fp8_delayed_scaling"):
         return transformer_engine.common.recipe.DelayedScaling(
             fp8_format=transformer_engine.common.recipe.Format.E4M3,
+            amax_history_len=8,
         )
     if name == "fp8_current_scaling":
         return transformer_engine.common.recipe.Float8CurrentScaling(
......
@@ -158,6 +158,9 @@ if(USE_CUDA)
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
      fused_rope/fused_rope.cu
+     fused_router/fused_moe_aux_loss.cu
+     fused_router/fused_score_for_moe_aux_loss.cu
+     fused_router/fused_topk_with_score_function.cu
      recipe/current_scaling.cu
      recipe/delayed_scaling.cu
      recipe/fp8_block_scaling.cu
@@ -211,6 +214,9 @@ else()
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
      fused_rope/fused_rope.cu
+     fused_router/fused_moe_aux_loss.cu
+     fused_router/fused_score_for_moe_aux_loss.cu
+     fused_router/fused_topk_with_score_function.cu
      recipe/current_scaling.cu
      recipe/delayed_scaling.cu
      recipe/fp8_block_scaling.cu
......
@@ -100,6 +100,16 @@ bool has_mnnvl_fabric(int device_id) {
   }
   return false;
 #else
+  // Check run-time CUDA version
+  if (transformer_engine::cuda::cudart_version() < 12040) {
+    if (getenv("NVTE_UBDEBUG")) {
+      printf(
+          "TransformerEngine does not support multi-node NVLINK "
+          "since it is not being run with CUDA version >= 12.4.\n");
+    }
+    return false;
+  }
   bool mnnvl_fabric_support = false;
   CUdevice dev;
   NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
......
@@ -248,7 +248,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
            attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
          // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
          (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
-          cudnn_runtime_version >= 91100))) &&
+          cudnn_runtime_version >= 91100)) &&
+        // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
+        (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 &&
+           head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
+           head_dim_qk != head_dim_v))) &&
        // bias type
        ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
         (cudnn_runtime_version >= 8906 &&
......
@@ -52,7 +52,7 @@ class OptionalCUDAGuard {
   ~OptionalCUDAGuard() {
     if (device_changed_) {
-      NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
+      cudaSetDevice(prev_device_);
     }
   }
......