Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
test_disable_fp8_layer:
  enabled: True
  layers:
    layer_types: [qkv]
  transformer_engine:
    DisableFP8Layer:
      enabled: True
\ No newline at end of file
dummy_feature_everywhere:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    TestDummyFeature:
      enabled: True
      tensors: [weight, activation, gradient, output, wgrad, dgrad]
      gemms: [wgrad, dgrad, fprop]
\ No newline at end of file
test_fake_quant_fp8:
  enabled: True
  layers:
    layer_numbers: [1]
    layer_types: [fc1, fc2]
  transformer_engine:
    FakeQuant:
      enabled: True
      gemms: [fprop, dgrad]
      tensors_struct:
        - tensor: activation
          quant_format: FP8E4M3
        - tensor: gradient
          quant_format: FP8E5M2
\ No newline at end of file
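A rough numerical picture of what a FakeQuant-style transform does, as a hedged sketch only (not the TransformerEngine implementation): the tensor is cast to FP8 and immediately dequantized, so the following GEMM still runs in high precision but sees FP8 rounding error. The helper name and scale choice below are illustrative assumptions.

import torch

def fake_quant_e4m3(x: torch.Tensor) -> torch.Tensor:
    # illustrative per-tensor scale: map the tensor's amax to the E4M3 maximum (448.0)
    FP8_MAX = 448.0
    scale = FP8_MAX / x.abs().max().clamp(min=1e-12)
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize to FP8
    return x_fp8.to(x.dtype) / scale              # dequantize back to the input dtype

x = torch.randn(16, 16)
print((x - fake_quant_e4m3(x)).abs().max())  # rounding error introduced by the FP8 round trip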
test_per_tensor_scaling:
  enabled: True
  layers:
    layer_numbers: [1]
    layer_types: [fc1, fc2]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [wgrad]
    PerTensorScaling:
      enabled: True
      gemms_struct:
        - gemm: fprop
          tensors_struct:
            - tensor: activation
            - tensor: weight
        - gemm: dgrad
          tensors_struct:
            - tensor: gradient
\ No newline at end of file
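As a hedged sketch of the idea behind per-tensor current scaling (an illustration, not TE internals): the scaling factor is derived from the current tensor's amax each step, rather than from a history of past amax values as in delayed scaling. The helper below is an assumed name used only for illustration.

import torch

def current_scaling_factor(x: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    amax = x.abs().max()                     # absolute maximum of the current tensor
    return fp8_max / amax.clamp(min=1e-12)   # scale so that amax maps to the FP8 maximum

x = torch.randn(1024)
print(current_scaling_factor(x))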
stat_collection_test_1:
  enabled: True
  layers:
    layer_numbers: [1, 3]
  LogTensorStats:
    enabled: True
    stats: [mean, std, l1_norm, l2_norm]
    tensors: [activation]
    freq: 1
    start_step: 100
    end_step: 500
  transformer_engine:
    LogTensorStats:
      enabled: True
      stats: [cur_amax, dynamic_range]
      tensors: [activation]
      freq: 2
      start_step: 100
      end_step: 500
    LogFp8TensorStats:
      enabled: True
      stats: [underflows%]
      tensors: [gradient]
      freq: 5
      start_step: 100
      end_step: 500

stat_collection_test_2:
  enabled: True
  layers:
    layer_numbers: [6, 7]
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors_struct:
        - tensor: activation
          stats: [cur_amax, dynamic_range, mean, std, l1_norm]
          freq: 2
          start_step: 100
          end_step: 500
        - tensor: weight
          stats: [mean, std, l1_norm, min, max]
          freq: 5
          start_step: 100
          end_step: 500

stat_collection_test_4:
  enabled: True
  layers:
    layer_numbers: [5]
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation]
      stats: [cur_amax, dynamic_range, mean, std, l1_norm]
    LogFp8TensorStats:
      enabled: True
      stats: [underflows%]
      tensors: [activation]
\ No newline at end of file
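A hedged sketch of two of the statistics requested above; the exact definitions used by the debug tools may differ. Here dynamic_range is taken as log2(amax / amin) over nonzero magnitudes, and underflows% as the share of nonzero values that would collapse to zero after an FP8 cast. Both helper names are assumptions for illustration only.

import torch

def dynamic_range(x: torch.Tensor) -> torch.Tensor:
    mags = x.abs()[x != 0]                                   # nonzero magnitudes
    return torch.log2(mags.max()) - torch.log2(mags.min())   # spread in powers of two

def underflows_percent(x: torch.Tensor) -> float:
    x_fp8 = x.to(torch.float8_e4m3fn).to(torch.float32)      # unscaled FP8 round trip
    nonzero = x != 0
    return 100.0 * ((x_fp8 == 0) & nonzero).sum().item() / nonzero.sum().item()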
# This config is used when FP8 training is ON
transformer_engine_fc1_manipulation:
  enabled: True
  layers:
    layer_name_regex_pattern: .*(fc1)  # Select layers if they end in fc1
  transformer_engine:  # namespace
    DisableFP8GEMM:  # Disable FP8 GEMM. FProp run in high precision
      enabled: True
      gemms: [fprop]
    PerTensorScaling:  # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM
      enabled: True
      gemms: [dgrad]
      tensors: [gradient]
    FakeQuant:  # Disable FP8 GEMM for Wgrad. Fake quantize activations to Wgrad and run high precision GEMM
      enabled: True
      gemms: [fprop]
      tensors_struct:
        - tensor: activation
          quant_format: FP8E4M3
        - tensor: weight
          quant_format: FP8E4M3

transformer_engine_fc2_manipulation:
  enabled: True
  layers:
    layer_name_regex_pattern: .*(fc2)  # Select layers if they end in fc2
  transformer_engine:  # namespace
    PerTensorScaling:  # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM
      enabled: True
      gemms_struct:
        - gemm: fprop
          tensors_struct:
            - tensor: activation
            - tensor: weight
        - gemm: wgrad
          tensors_struct:
            - tensor: activation
            - tensor: gradient
    FakeQuant:  # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM
      enabled: True
      gemms_struct:
        - gemm: dgrad
          tensors: [weight, gradient]
          quant_format: FP8E5M2
\ No newline at end of file
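A minimal usage sketch, mirroring the sanity test later in this commit: a config like the one above is passed to nvdlfw_inspect, and the listed features are applied to matching TE layers while the model runs under fp8_autocast. The config path and feature directory below are placeholders, and the layer choice is only an assumption of something containing fc1/fc2 GEMMs.

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te

debug_api.initialize(
    config_file="transformer_engine_fc_manipulation.yaml",  # placeholder path for the config above
    feature_dirs=["/path/to/transformer_engine/debug/features"],  # placeholder feature directory
)

model = te.LayerNormMLP(64, 64)     # assumed example layer with fc1/fc2 GEMMs matched by the regexes
inp = torch.randn(64, 64, 64).cuda()
with te.fp8_autocast(enabled=True):
    out = model(inp)
out.sum().backward()
debug_api.step()
debug_api.end_debug()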
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from pathlib import Path

import pytest
import torch

"""
Distributed numerics tests

These tests check the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and FP8 precision.
One test consists of running multiple configurations from the file run_numerics.py.
This design is used because initializing a test is slow - 2 processes need to start
and load torch and TE - so running multiple configurations in one test reduces the
initialization overhead.
"""

if torch.cuda.device_count() < 2:
    pytest.skip("Distributed training needs at least 2 GPUs.", allow_module_level=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def test_debug_distributed(feature_dirs):
    test_path = TEST_ROOT / "run_distributed.py"
    test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    if result.returncode != 0:
        raise AssertionError(result.stderr.decode())
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import functools
import itertools
import os
import random
import tempfile
from string import Template

import pytest
import torch

import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

from test_numerics import create_config_file

B, S, H, D = 64, 64, 64, 64

model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]

configs = {
    "": "",
    "log": """log:
  layers:
    layer_types: [linear]
  enabled: True
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation, gradient, weight, output, wgrad, dgrad]
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step: 0
      end_step: 1
    LogFp8TensorStats:
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [underflows, overflows]
      start_step: 0
      end_step: 1
""",
    "fake_quant": """
fake_quant_config:
  enabled: True
  layers:
    layer_types: [linear]
  transformer_engine:
    FakeQuant:
      enabled: True
      gemms: [fprop, dgrad, wgrad]
      quant_format: FP8E5M2
""",
}


def _get_model(model_key):
    if model_key == "linear":
        return te.Linear(D, D)
    if model_key == "layernorm_linear":
        return te.LayerNormLinear(D, D)
    if model_key == "layernorm_mlp":
        return te.LayerNormMLP(D, D, D)
    if model_key == "mha_attention":
        return te.MultiheadAttention(D, H)
    if model_key == "transformer_layer":
        return te.TransformerLayer(D, D, H)


def _run_forward_backward(model, fp8):
    for _ in range(3):
        inp = torch.randn((S, B, H)).cuda()
        with te.fp8_autocast(enabled=fp8):
            out = model(inp)
        out.sum().backward()
        debug_api.step()


@create_config_file
def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
    try:
        if config != "":
            config_file.write(config)
            config_file.flush()
        config_file_name = config_file.name if config != "" else ""
        debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name)
        model = _get_model(model_key)
        _run_forward_backward(model, fp8)
    except Exception as error:
        raise error
    finally:
        debug_api.end_debug()


@pytest.mark.parametrize("model_key", model_keys)
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
    _run_test(model_key, fp8, configs[config_key], feature_dirs)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os

LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log")


def reset_debug_log():
    if os.path.isfile(LOG_FILE):
        # delete all content
        with open(LOG_FILE, "w") as f:
            pass


def check_debug_log(msg):
    with open(LOG_FILE, "r") as f:
        for line in f.readlines():
            if msg in line:
                return True
    return False
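A hedged usage sketch for the helpers above: clear the log before a debug-enabled run, then look for a message a feature is expected to emit. The substring below is only a placeholder.

reset_debug_log()
# ... initialize debug_api, run forward/backward, call debug_api.step() ...
if check_debug_log("FakeQuant"):  # placeholder substring to search for
    print("feature message found in the nvdlfw_inspect log")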
@@ -274,7 +274,9 @@ def _main(opts):
     dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
     assert dist.is_nccl_available()
     dist.init_process_group(**dist_init_kwargs)
-    tp_group = dist.new_group(backend="nccl")
+    tp_group = dist.new_group(
+        backend="nccl", pg_options=dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
+    )
     tp_rank = dist.get_rank(tp_group)
     tp_size = dist.get_world_size(tp_group)
     dist_print(
...
@@ -323,6 +323,7 @@ def _train(opts):
         new_group_kwargs = {
             "backend": "nccl",
             "ranks": tp_rank_list,
+            "pg_options": dist.ProcessGroupNCCL.Options(is_high_priority_stream=True),
         }
     else:
         opts.tp = WORLD_SIZE
...
@@ -35,6 +35,18 @@ NCCL_WORLD = None
 LOSS_FN = nn.MSELoss()
 QUANTIZATION = None
 
+if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
+    # The numerics of all the layers should work the same when debug=True.
+    # A dummy feature is fed in to prevent debug from being switched off,
+    # which can happen if no feature is active.
+    import nvdlfw_inspect.api as debug_api
+
+    debug_api.initialize(
+        os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
+        feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
+    )
+
 # Disable TF32
 torch.backends.cuda.matmul.allow_tf32 = False

@@ -89,11 +101,15 @@ def main(argv=None, namespace=None):
 
     # Quantization scheme
     QUANTIZATION = args.quantization
-    if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
-        global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
+    global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
+    if QUANTIZATION in ("fp8", "mxfp8"):
         SEQ_LEN = 32
         BATCH_SIZE = 32
         HIDDEN_SIZE = 128
+    elif QUANTIZATION == "fp8_block_scaling":
+        SEQ_LEN = 128
+        BATCH_SIZE = 128
+        HIDDEN_SIZE = 512
 
     test_dict = [
         test_quantizer,

@@ -174,7 +190,7 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.3e-6, "atol": 4e-5}
+        return {"rtol": 1.2e-4, "atol": 1e-4}
     raise ValueError(f"Unsupported dtype ({dtype})")

@@ -638,7 +654,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
     if "return_layernorm_output" in kwargs:
         output_single_node, norm_s = output_single_node
         output_distributed, norm_d = output_distributed
-        if sequence_parallel:
+        if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
             norm_d = _gather(norm_d)
         _check_outputs(norm_s, norm_d)

@@ -747,7 +763,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
     if "return_layernorm_output" in kwargs:
         output_single_node, norm_s = output_single_node
         output_distributed, norm_d = output_distributed
-        if sequence_parallel:
+        if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
             norm_d = _gather(norm_d)
         _check_outputs(norm_s, norm_d)
...
@@ -57,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 torch._dynamo.reset()
 
 
-def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
+def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),

@@ -84,6 +84,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
         if torch.cuda.get_device_properties(0).major != 9:
             pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).")
         test_cmd.append("--atomic")
+    if aggregate:
+        test_cmd.append("--aggregate")
 
     logging.info(f"test_cmd: {test_cmd}")
     result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)

@@ -142,12 +144,13 @@ def _run_layer_with_overlap(
 @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
-def test_split_all_gather_overlaps(quantization):
+@pytest.mark.parametrize("aggregate", (False, True))
+def test_split_all_gather_overlaps(quantization, aggregate):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, quantization)
+    _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization)
 
 
 @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))

@@ -157,7 +160,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, quantization)
+    _run_gemm_with_overlap("RS", False, p2p, False, False, quantization)
 
 
 @pytest.mark.parametrize(

@@ -190,10 +193,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
                 " 9.0 (HOPPER ARCH)."
             )
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
+        _run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
+        _run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
 
 
 @pytest.mark.parametrize(
...
@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
 )
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 
 # Import utility functions
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
-from utils import dtype_tols, str_to_dtype
+from utils import dtype_tols, make_recipe, str_to_dtype
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 
 quantization_list: list[Optional[str]] = [None]
 if fp8_available:
-    quantization_list.append("fp8")
+    quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
 if mxfp8_available:
     quantization_list.append("mxfp8")

@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
 @torch.no_grad()
 def make_reference_and_test_tensors(
     shape: int | Iterable[int],
+    quantization: Optional[str] = None,
     ref_dtype: torch.dtype = torch.float64,
     ref_device: torch.device = "cpu",
     test_dtype: torch.dtype = torch.float32,
     test_device: torch.device = "cuda",
-    test_is_fp8: bool = False,
+    test_is_quantized: bool = False,
     requires_grad: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Construct tensors with the same values

@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
     operations in high precision. The test tensor is intended for use
     in Transformer Engine operations.
 
+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
     """
 
-    # Random data
+    # Random reference tensor
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
 
-    # Make copy of tensor
+    # Construct test tensor from reference tensor
     test = ref.to(device=test_device, dtype=test_dtype)
-    if test_is_fp8:
+    if quantization is None:
+        if test_is_quantized:
+            raise ValueError("Quantization scheme not provided")
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
         quantizer = Float8Quantizer(
-            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
             amax=torch.zeros(1, dtype=torch.float32, device=test_device),
             fp8_dtype=tex.DType.kFloat8E4M3,
         )
         test = quantizer(test)
-    elif test.data_ptr() == ref.data_ptr():
-        test = test.clone()
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+    if isinstance(test, QuantizedTensor) and not test_is_quantized:
+        test = test.dequantize()
 
-    # Make sure reference and test tensors represent exact same values
+    # Make sure reference and test tensors match each other
     ref.copy_(test)
 
-    # Return reference and test tensors
     ref.requires_grad_(requires_grad)
     test.requires_grad_(requires_grad)
     return ref, test
 
 
-def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
-    """Make recipe for quantization scheme"""
-    if name is None:
-        return None
-    if name == "fp8":
-        return transformer_engine.common.recipe.DelayedScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    if name == "mxfp8":
-        return transformer_engine.common.recipe.MXFP8BlockScaling(
-            fp8_format=transformer_engine.common.recipe.Format.E4M3,
-        )
-    raise ValueError(f"Unsupported quantization scheme ({name})")
-
-
 def _test_linear(
     *,
     model_config: ModelConfig,

@@ -201,21 +208,16 @@ def _test_linear(
     reset_rng()
     x_ref, x_test = make_reference_and_test_tensors(
         in_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(x_test, QuantizedTensor):
-        with torch.no_grad():
-            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
+        quantization=quantization,
         test_dtype=dtype,
        test_device=device,
-        test_is_fp8=quantized_compute,
     )
-    if isinstance(w_test, QuantizedTensor):
-        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":

@@ -229,13 +231,11 @@
         )
     dy_ref, dy_test = make_reference_and_test_tensors(
         out_shape,
+        quantization=quantization,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
-    if isinstance(dy_test, QuantizedTensor):
-        dy_test = dy_test.dequantize()
 
     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
...
@@ -108,6 +108,18 @@ model_configs_fused_attn = {
     "cp_2_4": ModelConfig(
         2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
     ),  # GQA
+    "cp_3_0": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_1": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_2": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
+    ),  # MLA
+    "cp_3_3": ModelConfig(
+        2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
+    ),  # MLA
 }

@@ -160,6 +172,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
     )
     if dtype != "fp8" and fp8_mha:
         pytest.skip("Only fp8 works with fp8_mha=True!")
+    if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently only support KV P2P!")
 
     subprocess.run(
         get_bash_arguments(
...
@@ -52,7 +52,7 @@ model_configs_infer = {
         4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
     ),
     "infer_1": ModelConfig(
-        2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
+        2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
     ),
 }

@@ -370,12 +370,24 @@ def generate_args(
 ]
 
 
-def get_tols(module, backend, dtype):
+def get_tols(config, module, backend, dtype):
     if module == "TransformerLayer":
-        tols = {
-            torch.half: (5e-3, 5e-3),
-            torch.bfloat16: (3.5e-2, 3.5e-2),
-        }
+        if config.head_dim_qk <= 128:
+            tols = {
+                torch.half: (5e-3, 5e-3),
+                torch.bfloat16: (3.5e-2, 3.5e-2),
+            }
+        else:
+            if backend == "UnfusedAttention":
+                tols = {
+                    torch.half: (1.6e-2, 1.6e-2),
+                    torch.bfloat16: (1.2e-1, 1e-1),
+                }
+            else:
+                tols = {
+                    torch.half: (1e-2, 1e-2),
+                    torch.bfloat16: (8e-2, 7e-2),
+                }
     if module == "DotProductAttention":
         tols = {
             torch.half: (1e-3, 1e-3),

@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
         incremental_output = incremental_output[0]
 
     # compare results
-    atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
+    atol, rtol = get_tols(
+        config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
+    )
     for i, seq in enumerate(sim.t_seq_ids):
         token_index = sim.step_lens[i] - 1
         if qkv_format == "bshd":
...
@@ -268,6 +268,7 @@ class BlockwiseQuantizerReference:
         eps: float = 0.0,
         pow_2_scales: bool = False,
         quant_tile_shape: Tuple[int, int] = (128, 128),
+        munge_scale_shapes: bool = True,
     ) -> QuantizeResult:
         # sanity checks
         assert x.dim() == 2

@@ -286,27 +287,33 @@
         assert quant_tile_shape in ((1, 128), (128, 128))
         if quant_tile_shape[0] == 1:
             # Quantize row-wise
-            return self.scale_munger.munge_scale_shapes_for_backend(
-                self._quantize_vector_tiling(
-                    x,
-                    quant_dtype,
-                    tile_len=quant_tile_shape[1],
-                    return_transpose=return_transpose,
-                    pow_2_scales=pow_2_scales,
-                    eps=eps,
-                ),
-                quant_tile_shape,
+            result = self._quantize_vector_tiling(
+                x,
+                quant_dtype,
+                tile_len=quant_tile_shape[1],
+                return_transpose=return_transpose,
+                pow_2_scales=pow_2_scales,
+                eps=eps,
             )
+            if munge_scale_shapes:
+                result = self.scale_munger.munge_scale_shapes_for_backend(
+                    result,
+                    quant_tile_shape,
+                )
+            return result
         else:
             # Quantize block-wise
-            return self.scale_munger.munge_scale_shapes_for_backend(
-                self._quantize_square_block_tiling(
-                    x,
-                    quant_dtype,
-                    tile_len=quant_tile_shape[0],
-                    return_transpose=return_transpose,
-                    pow_2_scales=pow_2_scales,
-                    eps=eps,
-                ),
-                quant_tile_shape,
+            result = self._quantize_square_block_tiling(
+                x,
+                quant_dtype,
+                tile_len=quant_tile_shape[0],
+                return_transpose=return_transpose,
+                pow_2_scales=pow_2_scales,
+                eps=eps,
             )
+            if munge_scale_shapes:
+                result = self.scale_munger.munge_scale_shapes_for_backend(
+                    result,
+                    quant_tile_shape,
+                )
+            return result
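For orientation, a hedged reference sketch of the 1x128 row-wise tiling handled by the quantizer above (an illustration of the scheme, not the class's actual code): each contiguous run of 128 values in a row gets its own scale derived from the tile's amax. The helper name and scale formula are assumptions.

import torch

def quantize_1x128(x: torch.Tensor, fp8_max: float = 448.0):
    rows, cols = x.shape
    assert cols % 128 == 0
    tiles = x.reshape(rows, cols // 128, 128)
    amax = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scales = fp8_max / amax                            # one scale per 1x128 tile
    q = (tiles * scales).to(torch.float8_e4m3fn)       # quantized tile data
    return q.reshape(rows, cols), scales.squeeze(-1)   # FP8 data and per-tile scales

data, scales = quantize_1x128(torch.randn(4, 256))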