Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
test_disable_fp8_layer:
enabled: True
layers:
layer_types: [qkv]
transformer_engine:
DisableFP8Layer:
enabled: True
dummy_feature_everywhere:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
TestDummyFeature:
enabled: True
tensors: [weight, activation, gradient, output, wgrad, dgrad]
gemms: [wgrad, dgrad, fprop]
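# TestDummyFeature above is intended as a no-op feature attached to every layer,
# keeping debug active for all tensors and GEMMs even when no other feature is used.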
test_fake_quant_fp8:
enabled: True
layers:
layer_numbers: [1]
layer_types: [fc1, fc2]
transformer_engine:
FakeQuant:
enabled: True
gemms: [fprop, dgrad]
tensors_struct:
- tensor: activation
quant_format: FP8E4M3
- tensor: gradient
quant_format: FP8E5M2
test_per_tensor_scaling:
enabled: True
layers:
layer_numbers: [1]
layer_types: [fc1, fc2]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [wgrad]
PerTensorScaling:
enabled: True
gemms_struct:
- gemm: fprop
tensors_struct:
- tensor: activation
- tensor: weight
- gemm: dgrad
tensors_struct:
- tensor: gradient
stat_collection_test_1:
enabled: True
layers:
layer_numbers: [1, 3]
LogTensorStats:
enabled: True
stats: [mean, std, l1_norm, l2_norm]
tensors: [activation]
freq: 1
start_step: 100
end_step: 500
transformer_engine:
LogTensorStats:
enabled: True
stats: [cur_amax, dynamic_range]
tensors: [activation]
freq: 2
start_step: 100
end_step: 500
LogFp8TensorStats:
enabled: True
stats: [underflows%]
tensors: [gradient]
freq: 5
start_step: 100
end_step: 500
stat_collection_test_2:
enabled: True
layers:
layer_numbers: [6, 7]
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
freq: 2
start_step: 100
end_step: 500
- tensor: weight
stats: [mean, std, l1_norm, min, max]
freq: 5
start_step: 100
end_step: 500
stat_collection_test_4:
enabled: True
layers:
layer_numbers: [5]
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation]
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
LogFp8TensorStats:
enabled: True
stats: [underflows%]
tensors: [activation]
# This config is used when FP8 training is ON
transformer_engine_fc1_manipulation:
enabled: True
layers:
layer_name_regex_pattern: .*(fc1) # Select layers if they end in fc1
transformer_engine: # namespace
DisableFP8GEMM: # Disable FP8 GEMM. FProp run in high precision
enabled: True
gemms: [fprop]
PerTensorScaling: # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM
enabled: True
gemms: [dgrad]
tensors: [gradient]
FakeQuant: # Disable FP8 GEMM for Wgrad. Fake quantize activations to Wgrad and run high precision GEMM
enabled: True
gemms: [fprop]
tensors_struct:
- tensor: activation
quant_format: FP8E4M3
- tensor: weight
quant_format: FP8E4M3
transformer_engine_fc2_manipulation:
enabled: True
layers:
layer_name_regex_pattern: .*(fc2) # Select layers if they end in fc2
transformer_engine: # namespace
PerTensorScaling: # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM
enabled: True
gemms_struct:
- gemm: fprop
tensors_struct:
- tensor: activation
- tensor: weight
- gemm: wgrad
tensors_struct:
- tensor: activation
- tensor: gradient
FakeQuant: # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM
enabled: True
gemms_struct:
- gemm: dgrad
tensors: [weight, gradient]
quant_format: FP8E5M2
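# Net effect for fc2 layers (summary of the section above): FProp and WGrad run as
# FP8 GEMMs with per-tensor (current) scaling of their inputs, while DGrad runs as a
# high precision GEMM on weights and gradients fake-quantized to FP8E5M2.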
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
"""
Distributed numerics tests

These tests check the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and FP8 precision.
One test runs multiple configurations from the file run_numerics.py.
This design is used because initializing a test is expensive: two processes
need to start and load torch and TE, so running multiple configurations per
test reduces the initialization overhead.
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
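# The command assembled below is equivalent to launching the test manually with:
#   torchrun --nproc_per_node=<NUM_PROCS> run_distributed.py --feature_dirs=<dir>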
def test_debug_distributed(feature_dirs):
test_path = TEST_ROOT / "run_distributed.py"
test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if result.returncode != 0:
raise AssertionError(result.stderr.decode())
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.module.base import (
_2X_ACC_DGRAD,
_2X_ACC_FPROP,
_2X_ACC_WGRAD,
)
all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID
AMAX_HISTORY_LEN = 16
FP8_RECIPE = DelayedScaling(
fp8_format=FP8_FORMAT, amax_history_len=AMAX_HISTORY_LEN, amax_compute_algo="max"
)
SEED = 1234
IN_SIZE = 128
OUT_SIZE = 64
BATCH_SIZE = 16
SEQ_LEN = 128
LOSS_FN = torch.nn.functional.cross_entropy
def _cast_to_fp8(tensor, scale, dtype):
tensor = tensor.contiguous()
if type(scale) == torch.Tensor:
# Delayed scaling: quantize with the provided scaling factor.
amax = tensor.abs().max().float()
quantizer = Float8Quantizer(scale, amax, dtype)
else:
# Current scaling: here `scale` carries the target FP8 dtype instead of a scaling
# factor; the factor is computed from the tensor's own amax at cast time.
quantizer = Float8CurrentScalingQuantizer(scale, device=tensor.device)
return quantizer(tensor)
def _get_current_scale(tensor, fp8_dtype):
if fp8_dtype == tex.DType.kFloat8E4M3:
fp8_max = Format.E4M3.value.max_fwd
else:
fp8_max = Format.E5M2.value.max_fwd
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
return _default_sf_compute(amax, one, fp8_max, 0).detach()
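# Note: _default_sf_compute implements the usual FP8 scaling rule, roughly
# scale = fp8_max / amax (with margin 0 here), so that the largest observed value
# maps to the FP8 representable maximum.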
def _fake_cast(tensor, fp8_dtype, scale):
scale = scale or _get_current_scale(tensor, fp8_dtype)
fp8_tensor = _cast_to_fp8(tensor, scale, fp8_dtype)
return fp8_tensor.dequantize()
def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split_accumulator):
fp8_tensor1 = _cast_to_fp8(tensor1, scale1, dtype1)
fp8_tensor2 = _cast_to_fp8(tensor2, scale2, dtype2)
out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1,
fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32,
use_split_accumulator=use_split_accumulator,
)
out.requires_grad = True
return out.T
def _emulate_linear(
input: torch.Tensor,
weight: torch.Tensor,
fprop_fp8: bool = False,
fprop_input_fake_quant: tex.DType = None,
fprop_input_scale: torch.Tensor = None,
fprop_weight_fake_quant: tex.DType = None,
fprop_weight_scale: torch.Tensor = None,
dgrad_fp8: bool = False,
dgrad_gradient_fake_quant: tex.DType = None,
dgrad_gradient_scale: torch.Tensor = None,
dgrad_weight_fake_quant: tex.DType = None,
dgrad_weight_scale: torch.Tensor = None,
wgrad_fp8: bool = False,
wgrad_gradient_fake_quant: tex.DType = None,
wgrad_gradient_scale: torch.Tensor = None,
wgrad_input_fake_quant: tex.DType = None,
wgrad_input_scale: torch.Tensor = None,
loss_multiplier: float = 1.0,
activation_sync=None,
gradient_sync=None,
):
_scalar = lambda x: torch.Tensor([x]).cuda() if type(x) in [float, torch.Tensor] else x
if fprop_fp8:
activation = _fp8_gemm_kernel(
input,
_scalar(fprop_input_scale or 1.0),
tex.DType.kFloat8E4M3,
weight,
_scalar(fprop_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
_2X_ACC_FPROP,
)
activation = activation.clone().detach().contiguous().requires_grad_(True)
else:
fprop_input = (
_fake_cast(input, fprop_input_fake_quant, _scalar(fprop_input_scale))
if fprop_input_fake_quant is not None
else input
)
fprop_weight = (
_fake_cast(weight, fprop_weight_fake_quant, _scalar(fprop_weight_scale))
if fprop_weight_fake_quant is not None
else weight
)
activation = (fprop_input @ fprop_weight.T).contiguous()
if activation_sync:
activation = activation_sync(activation)
activation.retain_grad()
(loss_multiplier * activation.sum()).backward(retain_graph=True)
gradient = activation.grad.clone()
if gradient_sync:
gradient = gradient_sync(gradient)
if dgrad_fp8:
dgrad = _fp8_gemm_kernel(
weight.T,
_scalar(dgrad_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
gradient,
_scalar(dgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
_2X_ACC_DGRAD,
).T
else:
dgrad_gradient = (
_fake_cast(gradient, dgrad_gradient_fake_quant, _scalar(dgrad_gradient_scale))
if dgrad_gradient_fake_quant is not None
else gradient
)
dgrad_weight = (
_fake_cast(weight, dgrad_weight_fake_quant, _scalar(dgrad_weight_scale))
if dgrad_weight_fake_quant is not None
else weight
)
dgrad = dgrad_gradient @ dgrad_weight
if wgrad_fp8:
wgrad = _fp8_gemm_kernel(
input.T,
_scalar(wgrad_input_scale or 1.0),
tex.DType.kFloat8E4M3,
gradient.T,
_scalar(wgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
_2X_ACC_WGRAD,
).T
else:
wgrad_gradient = (
_fake_cast(gradient, wgrad_gradient_fake_quant, _scalar(wgrad_gradient_scale))
if wgrad_gradient_fake_quant is not None
else gradient
)
wgrad_input = (
_fake_cast(input, wgrad_input_fake_quant, _scalar(wgrad_input_scale))
if wgrad_input_fake_quant is not None
else input
)
wgrad_input = wgrad_input.contiguous()
wgrad_gradient = wgrad_gradient.contiguous()
wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input,
wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32,
layout="NT",
grad=True,
use_split_accumulator=_2X_ACC_WGRAD,
)
return {"activation": activation, "wgrad": wgrad, "dgrad": dgrad}
def _init_debug(config_name, log_dir, feature_dirs):
debug_api.initialize(
config_file=config_name,
feature_dirs=feature_dirs,
log_dir=log_dir,
default_logging_enabled=True,
)
def create_config_file(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
with tempfile.TemporaryDirectory() as temp_dir:
try:
kwargs["config_file"] = temp_file
kwargs["log_dir"] = temp_dir
result = func(*args, **kwargs)
finally:
temp_file_name = temp_file.name
debug_api.end_debug()
os.unlink(temp_file_name)
return result
return wrapper
def _cmp(ground_truth, output):
torch.testing.assert_close(ground_truth["activation"], output["activation"])
torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"])
torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"])
def _init_model(weight):
model = transformer_engine.pytorch.Linear(IN_SIZE, OUT_SIZE, name="linear")
with torch.no_grad():
model.weight.copy_(weight.contiguous())
return model
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None):
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward()
debug_api.step()
return y
def _get_tensors():
torch.manual_seed(SEED)
x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
x.retain_grad()
weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()
return x, weight
DISABLE_FP8_CONFIG = Template(
"""disable_fp8_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [$gemms]
"""
)
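# Example: DISABLE_FP8_CONFIG.safe_substitute(gemms="fprop,dgrad") produces a config
# that runs FProp and DGrad in high precision and leaves only WGrad in FP8 for
# layers of type `linear`.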
@pytest.mark.parametrize("fprop_fp8", all_boolean)
@pytest.mark.parametrize("dgrad_fp8", all_boolean)
@pytest.mark.parametrize("wgrad_fp8", all_boolean)
def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)
def disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, config_file):
gemms = ""
if not fprop_fp8:
gemms += "fprop,"
if not dgrad_fp8:
gemms += "dgrad,"
if not wgrad_fp8:
gemms += "wgrad,"
if len(gemms) > 0:
gemms = gemms[:-1] # remove last ','
config_file.write(DISABLE_FP8_CONFIG.safe_substitute(gemms=gemms))
config_file.flush()
@create_config_file
def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwargs):
disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
fp8_kwargs = {
"fprop_fp8": fprop_fp8,
"dgrad_fp8": dgrad_fp8,
"wgrad_fp8": wgrad_fp8,
}
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
x.grad.zero_()
ground_truth = _emulate_linear(x, weight, **fp8_kwargs)
_cmp(ground_truth, output)
def test_disable_fp8_layer(feature_dirs):
run_disable_fp8_layer(feature_dirs)
DISABLE_FP8_LAYER_CONFIG = """disable_fp8_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
DisableFP8Layer:
enabled: True
"""
@create_config_file
def run_disable_fp8_layer(feature_dirs, **kwargs):
kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
kwargs["config_file"].flush()
x, weight = _get_tensors()
ground_truth = _emulate_linear(x, weight)
x.grad.zero_()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
_cmp(ground_truth, output)
random.seed(1234)
all_combinations = list(itertools.product(all_boolean, repeat=6))
subset_combinations = random.sample(all_combinations, 20)
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
run_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
PER_TENSOR_SCALING_CONFIG = Template(
"""per_tensor_scaling_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
PerTensorScaling:
enabled: True
gemms_struct:
$gemms
"""
)
def _prepare_per_tensor_scaling_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
gemms = ""
title = lambda x: f" - gemm: {x}\n tensors: ["
def add_tensor(if_add, gemm_name):
nonlocal gemms
if if_add:
gemms += gemm_name + ","
if fprop_inp or fprop_weight:
gemms += title("fprop")
add_tensor(fprop_inp, "activation")
add_tensor(fprop_weight, "weight")
gemms = gemms[:-1] + "]\n"
if dgrad_weight or dgrad_grad:
gemms += title("dgrad")
add_tensor(dgrad_weight, "weight")
add_tensor(dgrad_grad, "gradient")
gemms = gemms[:-1] + "]\n"
if wgrad_input or wgrad_grad:
gemms += title("wgrad")
add_tensor(wgrad_input, "activation")
add_tensor(wgrad_grad, "gradient")
gemms = gemms[:-1] + "]\n"
config_file.write(PER_TENSOR_SCALING_CONFIG.safe_substitute(gemms=gemms))
config_file.flush()
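# Illustrative example of the `gemms_struct` section generated above when
# fprop_inp and dgrad_grad are True (indentation approximate):
#   - gemm: fprop
#     tensors: [activation]
#   - gemm: dgrad
#     tensors: [gradient]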
def set_scaling_factors(model, input_kwargs, fp8_kwargs):
# Copy the delayed-scaling factors from the model into fp8_kwargs for tensors whose
# per-tensor scaling flag in input_kwargs is NOT set.
if not input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone()
if not input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone()
if not input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone()
if not input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone()
if not input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone()
if not input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone()
def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs):
# Mark tensors that use per-tensor (current) scaling when the respective flag in
# input_kwargs is set; the scaling factor is computed from the tensor itself at cast time.
if input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2
if input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2
if input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3
@create_config_file
def run_per_tensor_scaling(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
input_kwargs = {
"fprop_inp": fprop_inp,
"fprop_weight": fprop_weight,
"dgrad_weight": dgrad_weight,
"dgrad_grad": dgrad_grad,
"wgrad_input": wgrad_input,
"wgrad_grad": wgrad_grad,
}
fp8_kwargs = {
"fprop_fp8": True,
"dgrad_fp8": True,
"wgrad_fp8": True,
}
"""
Runs a test to validate per-tensor (current) scaling in FP8 computations.
The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling.
Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling;
similarly, the loss is multiplied by a large factor to alter the gradient's magnitude,
creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors.
Finally, a linear pass is emulated, and the results are compared.
"""
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
warmup_input, warmup_weight = _get_tensors()
model = _init_model(warmup_weight)
# Warm-up runs to set up the amax history and scaling factors.
for _ in range(AMAX_HISTORY_LEN):
_run_forward_backward(warmup_input, model)
x = torch.randn_like(warmup_input, requires_grad=True).cuda()
weight = torch.randn_like(warmup_weight, requires_grad=True).cuda()
model.weight.data = weight.data
x.retain_grad()
# Delayed scaling factors need to be collected before the forward pass with the
# test data, because that forward pass changes the scaling factors.
set_scaling_factors(model, input_kwargs, fp8_kwargs)
LOSS_MULTIPLIER = 100
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=True)
model.zero_grad()
y.retain_grad()
(
LOSS_MULTIPLIER * y.sum()
).backward()  # Loss multiplication to change the gradient's order of magnitude
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
# Per-tensor (current) scaling factors need to be collected after the forward pass
# with the test data, because the gradient (y.grad) is not available before it.
set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
ground_truth = _emulate_linear(x, weight, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs)
_cmp(ground_truth, output)
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_microbatching_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
@create_config_file
def run_microbatching_test(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
# Prepare the configuration file
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
# Initialize debug
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
# Get data
x_full, weight = _get_tensors()
microbatch_size = x_full.size(0) // 2
x_mb1 = x_full[:microbatch_size, ...].clone().detach().requires_grad_(True)
x_mb2 = x_full[microbatch_size:, ...].clone().detach().requires_grad_(True)
def init_and_warmup():
model = _init_model(weight)
_run_forward_backward(x_mb1, model, loss_scale=0.5)
_run_forward_backward(x_mb2, model, loss_scale=0.5)
return model
# Run without is_first_microbatch
model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors
y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5)
y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5)
# Collect outputs
output1 = {
"activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0),
"wgrad": model.weight.grad.clone(),
"dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0),
}
# Run with is_first_microbatch
model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors
y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5, is_first_microbatch=True)
y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5, is_first_microbatch=False)
# Collect outputs
output2 = {
"activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0),
"wgrad": model.weight.grad.clone(),
"dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0),
}
# Compare outputs
torch.testing.assert_close(output1["activation"], output2["activation"], atol=1.0, rtol=0.5)
torch.testing.assert_close(output1["dgrad"], output2["dgrad"], atol=1.0, rtol=0.5)
torch.testing.assert_close(output1["wgrad"], output2["wgrad"], atol=1.0, rtol=0.5)
# Run the test
run_microbatching_test(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
all_combinations = list(
itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6)
)
subset_combinations = random.sample(all_combinations, 10)
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
run_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
FAKE_QUANT_CONFIG = Template(
"""fake_quant_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
FakeQuant:
enabled: True
gemms_struct:
$gemms
"""
)
def fake_quant_fp8_create_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"}
gemms = ""
def _add_tensor(quant_format, tensor):
nonlocal gemms
if quant_format:
gemms += " " * 8 + "- tensor: " + tensor + "\n"
gemms += " " * 8 + " quant_format: " + format_to_str[quant_format] + "\n"
title = lambda x: f" - gemm: {x}\n tensors_struct:\n"
if fprop_inp or fprop_weight:
gemms += title("fprop")
_add_tensor(fprop_inp, "activation")
_add_tensor(fprop_weight, "weight")
gemms = gemms[:-1] + "\n"
if dgrad_weight or dgrad_grad:
gemms += title("dgrad")
_add_tensor(dgrad_weight, "weight")
_add_tensor(dgrad_grad, "gradient")
gemms = gemms[:-1] + "\n"
if wgrad_input or wgrad_grad:
gemms += title("wgrad")
_add_tensor(wgrad_input, "activation")
_add_tensor(wgrad_grad, "gradient")
gemms = gemms[:-1] + "\n"
config = FAKE_QUANT_CONFIG.safe_substitute(gemms=gemms)
config_file.write(config)
config_file.flush()
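# Illustrative example of the `gemms_struct` section generated above when only
# fprop_inp is set to FP8E4M3 (indentation approximate):
#   - gemm: fprop
#     tensors_struct:
#       - tensor: activation
#         quant_format: FP8E4M3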
@create_config_file
def run_fake_quant_fp8(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
fp8_kwargs = {
"fprop_input_fake_quant": fprop_inp,
"fprop_weight_fake_quant": fprop_weight,
"dgrad_gradient_fake_quant": dgrad_grad,
"dgrad_weight_fake_quant": dgrad_weight,
"wgrad_gradient_fake_quant": wgrad_grad,
"wgrad_input_fake_quant": wgrad_input,
"fprop_fp8": not (fprop_inp or fprop_weight),
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
fake_quant_fp8_create_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
ground_truth = _emulate_linear(x, weight, **fp8_kwargs)
_cmp(ground_truth, output)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from test_numerics import create_config_file
B, S, H, D = 64, 64, 64, 64
model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]
configs = {
"": "",
"log": """log:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows, overflows]
start_step : 0
end_step: 1
""",
"fake_quant": """
fake_quant_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
quant_format: FP8E5M2
""",
}
def _get_model(model_key):
if model_key == "linear":
return te.Linear(D, D)
if model_key == "layernorm_linear":
return te.LayerNormLinear(D, D)
if model_key == "layernorm_mlp":
return te.LayerNormMLP(D, D)
if model_key == "mha_attention":
return te.MultiheadAttention(D, H)
if model_key == "transformer_layer":
return te.TransformerLayer(D, D, H)
def _run_forward_backward(model, fp8):
for _ in range(3):
inp = torch.randn((S, B, H)).cuda()
with te.fp8_autocast(enabled=fp8):
out = model(inp)
out.sum().backward()
debug_api.step()
@create_config_file
def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
try:
if config != "":
config_file.write(config)
config_file.flush()
config_file_name = config_file.name if config != "" else ""
debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name)
model = _get_model(model_key)
_run_forward_backward(model, fp8)
finally:
debug_api.end_debug()
@pytest.mark.parametrize("model_key", model_keys)
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
_run_test(model_key, fp8, configs[config_key], feature_dirs)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log")
def reset_debug_log():
if os.path.isfile(LOG_FILE):
# delete all content
with open(LOG_FILE, "w") as f:
pass
def check_debug_log(msg):
with open(LOG_FILE, "r") as f:
for line in f.readlines():
if msg in line:
return True
return False
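def assert_debug_log_contains(msg):
    # Convenience wrapper (illustrative sketch, not used by the tests above):
    # fail with the log path when an expected message is missing.
    assert check_debug_log(msg), f"'{msg}' not found in {LOG_FILE}"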
......@@ -274,7 +274,9 @@ def _main(opts):
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
tp_group = dist.new_group(backend="nccl")
tp_group = dist.new_group(
backend="nccl", pg_options=dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
)
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
......
......@@ -323,6 +323,7 @@ def _train(opts):
new_group_kwargs = {
"backend": "nccl",
"ranks": tp_rank_list,
"pg_options": dist.ProcessGroupNCCL.Options(is_high_priority_stream=True),
}
else:
opts.tp = WORLD_SIZE
......
......@@ -35,6 +35,18 @@ NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
# The numerics of all layers should be the same when debug=True.
# A dummy feature is fed to every layer to prevent debug from being
# switched off, which can happen if no feature is active.
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
......@@ -89,11 +101,15 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
if QUANTIZATION in ("fp8", "mxfp8"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
HIDDEN_SIZE = 512
test_dict = [
test_quantizer,
......@@ -174,7 +190,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 4e-5}
return {"rtol": 1.2e-4, "atol": 1e-4}
raise ValueError(f"Unsupported dtype ({dtype})")
......@@ -638,7 +654,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
......@@ -747,7 +763,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
......
......@@ -57,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
......@@ -84,6 +84,8 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).")
test_cmd.append("--atomic")
if aggregate:
test_cmd.append("--aggregate")
logging.info(f"test_cmd: {test_cmd}")
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
......@@ -142,12 +144,13 @@ def _run_layer_with_overlap(
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
def test_split_all_gather_overlaps(quantization):
@pytest.mark.parametrize("aggregate", (False, True))
def test_split_all_gather_overlaps(quantization, aggregate):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, quantization)
_run_gemm_with_overlap("AG", False, True, False, aggregate, quantization)
@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
......@@ -157,7 +160,7 @@ def test_split_reduce_scatter_overlaps(quantization, p2p):
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, quantization)
_run_gemm_with_overlap("RS", False, p2p, False, False, quantization)
@pytest.mark.parametrize(
......@@ -190,10 +193,10 @@ def test_bulk_overlaps(comm_type, quantization, connections):
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
_run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, quantization)
_run_gemm_with_overlap(comm_type, True, False, False, False, quantization)
@pytest.mark.parametrize(
......
......@@ -22,19 +22,28 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -63,11 +72,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -76,78 +86,55 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors match each other
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_all_reduce(
*,
local_size: int = 17,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -156,22 +143,25 @@ def _test_all_reduce(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -199,10 +189,10 @@ def _test_all_reduce(
def _test_all_gather(
*,
local_size: int = 13,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -211,26 +201,29 @@ def _test_all_gather(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
in_shape = [world_size, local_size, local_size]
out_shape = [world_size, world_size * local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref = x_ref.tile((world_size, 1, 1)).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
......@@ -257,10 +250,10 @@ def _test_all_gather(
def _test_reduce_scatter(
*,
local_size: int = 11,
local_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
quantization: Optional[str] = None,
) -> None:
# Distributed process group
......@@ -269,22 +262,25 @@ def _test_reduce_scatter(
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
in_shape = [world_size, world_size * local_size, local_size]
out_shape = [world_size, local_size, local_size]
# Random data
reset_rng()
with_quantization = quantization is not None
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
test_is_quantized=with_quantization,
)
# Plain PyTorch implementation
......@@ -324,7 +320,11 @@ def _test_basic_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -348,30 +348,23 @@ def _test_basic_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -468,7 +461,11 @@ def _test_linear(
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and quantized_weight:
return
# Distributed process group
process_group = world_group()
......@@ -492,21 +489,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -520,13 +512,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......@@ -773,9 +763,10 @@ def run_parallel_tests() -> None:
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
for quantization in quantization_list:
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
print(f"Running _test_all_gather with quantization={quantization}")
_test_all_gather(quantization=quantization)
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
......
......@@ -26,21 +26,25 @@ from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype
from utils import dtype_tols, make_recipe, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
......@@ -118,11 +122,12 @@ def reset_rng(seed: int = 1234) -> None:
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
quantization: Optional[str] = None,
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
test_is_quantized: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
......@@ -131,47 +136,49 @@ def make_reference_and_test_tensors(
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
If a quantization scheme is provided, the tensor values are
quantized so that they are representable.
"""
# Random data
# Random reference tensor
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
# Construct test tensor from reference tensor
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
if quantization is None:
if test_is_quantized:
raise ValueError("Quantization scheme not provided")
if test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization in ("fp8", "fp8_delayed_scaling"):
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
elif quantization == "fp8_current_scaling":
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=test_device,
)
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
test = test.dequantize()
# Make sure reference and test tensors represent exact same values
# Make sure reference and test tensors match each other
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_linear(
*,
model_config: ModelConfig,
......@@ -201,21 +208,16 @@ def _test_linear(
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
......@@ -229,13 +231,11 @@ def _test_linear(
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
quantization=quantization,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
......
......@@ -223,13 +223,19 @@ def _get_attention_backends(
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
# test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
}
......@@ -271,12 +277,26 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -297,7 +317,6 @@ def test_dot_product_attention(
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
......@@ -361,6 +380,7 @@ def test_dot_product_attention(
is_training,
)
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
......@@ -400,12 +420,18 @@ if IS_HIP_EXTENSION:
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
}
else:
model_configs_mla = {
......@@ -416,18 +442,27 @@ else:
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
......@@ -1041,6 +1076,8 @@ def _run_dot_product_attention(
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
......@@ -1124,9 +1161,11 @@ model_configs_te_layer = {
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
......@@ -1137,7 +1176,7 @@ model_configs_te_layer = {
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
......@@ -1156,16 +1195,30 @@ def test_transformer_layer(
qkv_layout = "sbhd_sbhd_sbhd"
# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# Skip if qkv_format = thd and "padding" not in attn_mask_type
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD requires padding mask.")
# UnfusedDotProductAttention backend
if unfused_attn_supported:
......@@ -1178,6 +1231,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
# FusedAttention backend
......@@ -1191,6 +1245,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
# FlashAttention backend
......@@ -1204,8 +1259,10 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
logging.info(f"[test_transformer_layer]: is_training = {is_training}")
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
......@@ -1272,6 +1329,7 @@ def _run_transformer_layer(
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
......@@ -1286,6 +1344,7 @@ def _run_transformer_layer(
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
if qkv_format == "sbhd":
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
......@@ -1294,40 +1353,75 @@ def _run_transformer_layer(
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.max_seqlen_kv,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
if qkv_format == "bshd":
inp = torch.randn(
config.batch_size,
config.max_seqlen_q,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.batch_size,
config.max_seqlen_kv,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# Create seqlens
if "padding" in config.attn_mask_type:
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
if qkv_format == "thd":
inp = torch.randn(
cu_seqlens_q[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
cu_seqlens_kv[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
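# cu_seqlens_q/kv hold cumulative sequence lengths; for the packed thd format the inputs drop the batch dimension and their leading size is the total token count cu_seqlens[-1].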
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -1379,7 +1473,7 @@ def _run_transformer_layer(
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
layer_type="encoder" if config.attn_type == "self" else "decoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
......@@ -1389,6 +1483,8 @@ def _run_transformer_layer(
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
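# Switch to eval mode for inference runs so stochastic layers (dropout, drop path) stay inactive.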
if not is_training:
block = block.eval()
# Create ALiBi slopes
alibi_slopes = None
......@@ -1398,14 +1494,20 @@ def _run_transformer_layer(
# Run a forward pass, and a backward pass only when training
out = block(
inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
encoder_output=inp_enc if config.attn_type == "cross" else None,
enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
if is_training:
loss = out.sum()
loss.backward()
......
......@@ -108,6 +108,18 @@ model_configs_fused_attn = {
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_3_0": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # MLA
"cp_3_1": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64
), # MLA
"cp_3_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64
), # MLA
}
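# The cp_3_* entries exercise MLA attention (head_dim_v=64 vs head_dim_qk=128) under context parallelism.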
......@@ -160,6 +172,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
subprocess.run(
get_bash_arguments(
......
......@@ -52,7 +52,7 @@ model_configs_infer = {
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}
......@@ -370,12 +370,24 @@ def generate_args(
]
def get_tols(module, backend, dtype):
def get_tols(config, module, backend, dtype):
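# Returns (atol, rtol) for the given dtype; TransformerLayer with head dim above 128 (and the unfused backend) gets looser tolerances.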
if module == "TransformerLayer":
if config.head_dim_qk <= 128:
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
else:
if backend == "UnfusedAttention":
tols = {
torch.half: (1.6e-2, 1.6e-2),
torch.bfloat16: (1.2e-1, 1e-1),
}
else:
tols = {
torch.half: (1e-2, 1e-2),
torch.bfloat16: (8e-2, 7e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
......@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
incremental_output = incremental_output[0]
# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
atol, rtol = get_tols(
config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
......
......@@ -268,6 +268,7 @@ class BlockwiseQuantizerReference:
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (128, 128),
munge_scale_shapes: bool = True,
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
......@@ -286,27 +287,33 @@ class BlockwiseQuantizerReference:
assert quant_tile_shape in ((1, 128), (128, 128))
if quant_tile_shape[0] == 1:
# Quantize row-wise
result = self._quantize_vector_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[1],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
)
if munge_scale_shapes:
result = self.scale_munger.munge_scale_shapes_for_backend(
result,
quant_tile_shape,
)
return result
else:
# Quantize block-wise
result = self._quantize_square_block_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[0],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
)
if munge_scale_shapes:
result = self.scale_munger.munge_scale_shapes_for_backend(
result,
quant_tile_shape,
)
return result
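# With munge_scale_shapes=False the quantization result keeps its raw per-tile scale layout instead of being reshaped for the backend.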