Commit 1b303e91 authored by yuguo
parents 52ba87a1 735227cd
@@ -36,6 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_int8_blockwise_layers.xml $TE_PATH/tests/pytorch/test_int8_blockwise_layers.py || test_fail "test_int8_blockwise_layers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
def cublas_gemm_fp8_blockwise_case_fw(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
*,
x_columnwise: bool = False,
w_columnwise: bool = False,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
atol: float = 5e-1,
rtol: float = 5e-1
):
if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_x_1d_scaled or is_w_1d_scaled):
pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x_shape = (K, M) if x_columnwise else (M, K)
w_shape = (K, N) if w_columnwise else (N, K)
# generate random input and weight
if noise_type == "uniform":
x = torch.rand(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude * 2 - x_magnitude
w = torch.rand(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude * 2 - w_magnitude
elif noise_type == "normal":
x = torch.randn(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude
w = torch.randn(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude
else:
assert False, f"unknown noise_type: {noise_type}"
bf16_out = torch.matmul(x, w.t())
# print(f"x.shape: {x.shape}, w.shape: {w.shape}")
# print("bf16 gemm output: ", bf16_out)
# print("bf16 gemm output shape: ", bf16_out.shape)
# Setup out tensor if accumulate is True
if accumulate:
out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude
else:
out = None
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
w_te_dtype = TE_DType[w_dtype]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Reference GEMM
ref_gemm = CuBLASRefBlockwiseGemm()
scale_decoder = CuBLASScaleMunger()
qx_data = (
qx._columnwise_data.view(dtype=x_dtype)
if x_columnwise
else qx._rowwise_data.view(dtype=x_dtype)
)
qw_data = (
qw._columnwise_data.view(dtype=w_dtype)
if w_columnwise
else qw._rowwise_data.view(dtype=w_dtype)
)
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
# print(f"qx_data.shape: {qx_data.shape}, qw_data.shape: {qw_data.shape}")
# print(f"ref_scales_x.shape: {ref_scales_x.shape}, ref_scales_w.shape: {ref_scales_w.shape}")
# print(f"ref_scales_x_t.shape: {ref_scales_x.t().shape}")
# print(f"ref_scales_x_columnwise.shape: {qx._columnwise_scale_inv.shape}")
y_ref = ref_gemm.qgemm(
qx=qx_data,
qw=qw_data,
out_dtype=out_dtype,
demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape
),
demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape
),
quant_tile_shape_x=x_quant_tile_shape,
quant_tile_shape_w=w_quant_tile_shape,
bias=bias,
out=out.clone() if accumulate else None,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)
# print("fp8 gemm output: ", y_ref)
# print("fp8 gemm output shape: ", y_ref.shape)
x_te_dtype = TE_DType[torch.int8]
w_te_dtype = TE_DType[torch.int8]
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qx = x_quantizer.make_empty(x_shape, dtype=torch.int8, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qw = w_quantizer.make_empty(w_shape, dtype=torch.int8, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
qx_data = (
qx._columnwise_data.view(dtype=torch.int8)
if x_columnwise
else qx._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
qw._columnwise_data.view(dtype=torch.int8)
if w_columnwise
else qw._rowwise_data.view(dtype=torch.int8)
)
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
# print("int8 gemm output: ", y)
# print("int8 gemm output shape: ", y.shape)
torch.testing.assert_close(y, bf16_out, atol=atol, rtol=rtol)
torch.testing.assert_close(y_ref, bf16_out, atol=atol, rtol=rtol)
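# For intuition, a minimal sketch of what the (1, 128) "1D scaling" used above
# means (hypothetical helper, not used by these tests; assumes scale_inv holds
# one inverse scale per (row, k-block)):
def _dequantize_rowwise_blockwise_sketch(
    qdata: torch.Tensor, scale_inv: torch.Tensor, block_len: int = 128
) -> torch.Tensor:
    """Undo per-(1, block_len) quantization: x ~= q * s, block by block."""
    m, k = qdata.shape
    # Broadcast each block's inverse scale across its block_len columns.
    scales = scale_inv[:, : (k + block_len - 1) // block_len].to(torch.float32)
    scales = scales.repeat_interleave(block_len, dim=1)[:, :k]
    return qdata.to(torch.float32) * scales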
def cublas_gemm_fp8_blockwise_case_bw_xgrad(
dout_dtype,
w_dtype,
dx_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_w_1d_scaled,
*,
dout_columnwise: bool = False,
w_columnwise: bool = True,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
atol: float = 5e-1,
rtol: float = 5e-1
):
if dout_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_dout_1d_scaled or is_w_1d_scaled):
pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
dout_shape = (M, N)
w_shape = (N, K)
# generate random input and weight
if noise_type == "uniform":
dout = torch.rand(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude * 2 - dout_magnitude
w = torch.rand(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude * 2 - w_magnitude
elif noise_type == "normal":
dout = torch.randn(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude
w = torch.randn(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude
else:
assert False, f"unknown noise_type: {noise_type}"
bf16_dx = torch.matmul(dout, w)
# print("bf16 gemm dx: ", bf16_dx)
# Setup out tensor if accumulate is True
if accumulate:
dx = torch.randn((M, K), dtype=dx_dtype, device=device) * dout_magnitude
else:
dx = None
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
w_te_dtype = TE_DType[w_dtype]
dout_quantizer = Float8BlockQuantizer(
fp8_dtype=dout_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=dout_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize dout and w
qdout = dout_quantizer.make_empty(dout_shape, dtype=dout_dtype, device=device, requires_grad=False)
qdout = dout_quantizer.update_quantized(dout, qdout)
qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Reference GEMM
ref_gemm = CuBLASRefBlockwiseGemm()
scale_decoder = CuBLASScaleMunger()
qdout_data = (
qdout._columnwise_data.view(dtype=dout_dtype)
if dout_columnwise
else qdout._rowwise_data.view(dtype=dout_dtype)
)
qw_data = (
qw._columnwise_data.view(dtype=w_dtype)
if w_columnwise
else qw._rowwise_data.view(dtype=w_dtype)
)
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y_ref = ref_gemm.qgemm(
qx=qdout_data,
qw=qw_data,
out_dtype=dx_dtype,
demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(M, N), scales=ref_scales_dout, tile_shape=dout_quant_tile_shape
),
demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(K, N), scales=ref_scales_w, tile_shape=w_quant_tile_shape
),
quant_tile_shape_x=dout_quant_tile_shape,
quant_tile_shape_w=w_quant_tile_shape,
bias=bias,
out=dx.clone() if accumulate else None,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)
# print("fp8 gemm dx: ", y_ref)
dout_te_dtype = TE_DType[torch.int8]
w_te_dtype = TE_DType[torch.int8]
dout_quantizer = Float8BlockQuantizer(
fp8_dtype=dout_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=dout_block_scaling_dim,
)
w_quantizer = Float8BlockQuantizer(
fp8_dtype=w_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=w_block_scaling_dim,
)
# Quantize x and w
qdout = dout_quantizer.make_empty(dout_shape, dtype=torch.int8, device=device, requires_grad=False)
qdout = dout_quantizer.update_quantized(dout, qdout)
qw = w_quantizer.make_empty(w_shape, dtype=torch.int8, device=device, requires_grad=False)
qw = w_quantizer.update_quantized(w, qw)
qdout_data = (
qdout._columnwise_data.view(dtype=torch.int8)
if dout_columnwise
else qdout._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
qw._columnwise_data.view(dtype=torch.int8)
if w_columnwise
else qw._rowwise_data.view(dtype=torch.int8)
)
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=dx_dtype
)
# print("int8 gemm dx: ", y)
torch.testing.assert_close(y_ref, bf16_dx, atol=atol, rtol=rtol)
torch.testing.assert_close(y, bf16_dx, atol=atol, rtol=rtol)
def cublas_gemm_fp8_blockwise_case_bw_wgrad(
dout_dtype,
x_dtype,
dw_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
x_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_x_1d_scaled,
*,
dout_columnwise: bool = True,
x_columnwise: bool = True,
use_bias: bool = False,
use_gelu: bool = False,
use_grad: bool = False,
atol: float = 5e-1,
rtol: float = 5e-1
):
if dout_dtype == torch.float8_e5m2 and x_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_dout_1d_scaled or is_x_1d_scaled):
pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
dout_shape = (M, N)
x_shape = (M, K)
# generate random input and weight
if noise_type == "uniform":
dout = torch.rand(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude * 2 - dout_magnitude
x = torch.rand(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude * 2 - x_magnitude
elif noise_type == "normal":
dout = torch.randn(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude
x = torch.randn(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude
else:
assert False, f"unknown noise_type: {noise_type}"
bf16_dw = torch.matmul(dout.t(), x)
# print("bf16 gemm dw: ", bf16_dw)
# Setup out tensor if accumulate is True
if accumulate:
dw = torch.randn((N, K), dtype=dw_dtype, device=device) * dout_magnitude
else:
dw = None
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
x_te_dtype = TE_DType[x_dtype]
dout_quantizer = Float8BlockQuantizer(
fp8_dtype=dout_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=dout_block_scaling_dim,
)
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
# Quantize dout and w
qdout = dout_quantizer.make_empty(dout_shape, dtype=dout_dtype, device=device, requires_grad=False)
qdout = dout_quantizer.update_quantized(dout, qdout)
qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
if not use_bias:
bias = None
else:
bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)
# Reference GEMM
ref_gemm = CuBLASRefBlockwiseGemm()
scale_decoder = CuBLASScaleMunger()
qdout_data = (
qdout._columnwise_data.view(dtype=dout_dtype)
if dout_columnwise
else qdout._rowwise_data.view(dtype=dout_dtype)
)
qx_data = (
qx._columnwise_data.view(dtype=x_dtype)
if x_columnwise
else qx._rowwise_data.view(dtype=x_dtype)
)
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
y_ref = ref_gemm.qgemm(
qx=qdout_data,
qw=qx_data,
out_dtype=dw_dtype,
demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(N, M), scales=ref_scales_dout, tile_shape=dout_quant_tile_shape
),
demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
qtensor_shape=(K, M), scales=ref_scales_x, tile_shape=x_quant_tile_shape
),
quant_tile_shape_x=dout_quant_tile_shape,
quant_tile_shape_w=x_quant_tile_shape,
bias=bias,
out=dw.clone() if accumulate else None,
accumulate=accumulate,
use_split_accumulator=use_split_accumulator,
)
# print("fp8 gemm dw: ",y_ref)
dout_te_dtype = TE_DType[torch.int8]
x_te_dtype = TE_DType[torch.int8]
dout_quantizer = Float8BlockQuantizer(
fp8_dtype=dout_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=dout_block_scaling_dim,
)
x_quantizer = Float8BlockQuantizer(
fp8_dtype=x_te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=0.0,
force_pow_2_scales=True,
block_scaling_dim=x_block_scaling_dim,
)
# Quantize x and w
qdout = dout_quantizer.make_empty(dout_shape, dtype=torch.int8, device=device, requires_grad=False)
qdout = dout_quantizer.update_quantized(dout, qdout)
qx = x_quantizer.make_empty(x_shape, dtype=torch.int8, device=device, requires_grad=False)
qx = x_quantizer.update_quantized(x, qx)
qdout_data = (
qdout._columnwise_data.view(dtype=torch.int8)
if dout_columnwise
else qdout._rowwise_data.view(dtype=torch.int8)
)
qx_data = (
qx._columnwise_data.view(dtype=torch.int8)
if x_columnwise
else qx._rowwise_data.view(dtype=torch.int8)
)
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=dw_dtype
)
# print("int8 gemm dw: ",y)
torch.testing.assert_close(y_ref, bf16_dw, atol=atol, rtol=rtol)
torch.testing.assert_close(y, bf16_dw, atol=atol, rtol=rtol)
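# For orientation (a sketch, not used by the tests): the three cases above
# mirror the three GEMM layouts that general_gemm dispatches under
# NVTE_INT8_SIM_FP8, assuming the usual forward/backward conventions.
_GEMM_LAYOUT_SKETCH = {
    "fw:    y  = x @ w.t()": "TN",    # rowwise x, rowwise w
    "xgrad: dx = dy @ w": "NN",       # rowwise dy, columnwise w
    "wgrad: dw = dy.t() @ x": "NT",   # columnwise dy, columnwise x
}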
def test_cublas_gemm_fp8_blockwise_fw(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case_fw(
x_dtype,
w_dtype,
out_dtype,
M,
K,
N,
noise_type,
x_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_x_1d_scaled,
is_w_1d_scaled,
)
def test_cublas_gemm_fp8_blockwise_bw_xgrad(
dout_dtype,
w_dtype,
dx_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_w_1d_scaled,
):
cublas_gemm_fp8_blockwise_case_bw_xgrad(
dout_dtype,
w_dtype,
dx_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
w_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_w_1d_scaled,
)
def test_cublas_gemm_fp8_blockwise_bw_wgrad(
dout_dtype,
x_dtype,
dw_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
x_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_x_1d_scaled,
):
cublas_gemm_fp8_blockwise_case_bw_wgrad(
dout_dtype,
x_dtype,
dw_dtype,
M,
K,
N,
noise_type,
dout_magnitude,
x_magnitude,
accumulate,
use_split_accumulator,
is_dout_1d_scaled,
is_x_1d_scaled,
)
if __name__ == "__main__":
test_cublas_gemm_fp8_blockwise_fw(
x_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
w_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
out_dtype=torch.bfloat16,
M=128, # batch_size * seq_len
K=512, # in_feature
N=256, # out_feature
noise_type="normal",
x_magnitude=1e-1,
w_magnitude=1,
accumulate=False,
use_split_accumulator=True,
is_x_1d_scaled=True,
is_w_1d_scaled=False
)
test_cublas_gemm_fp8_blockwise_bw_xgrad(
dout_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
w_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
dx_dtype=torch.bfloat16,
M=128, # batch_size * seq_len
K=512, # in_feature
N=256, # out_feature
noise_type="normal",
dout_magnitude=1e-1,
w_magnitude=1,
accumulate=False,
use_split_accumulator=True,
is_dout_1d_scaled=True,
is_w_1d_scaled=False,
)
test_cublas_gemm_fp8_blockwise_bw_wgrad(
dout_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
x_dtype=torch.float8_e4m3fn, # torch.float8_e4m3fnuz if TE's e4m3 uses the fnuz variant
dw_dtype=torch.bfloat16,
M=128, # batch_size * seq_len
K=512, # in_feature
N=256, # out_feature
noise_type="normal",
dout_magnitude=1e-1,
x_magnitude=1,
accumulate=False,
use_split_accumulator=True,
is_dout_1d_scaled=True,
is_x_1d_scaled=True,
)
\ No newline at end of file
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import math
import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
QuantizeResult,
)
from test_float8_current_scaling_exact import (
TestFP8RecipeLinearBase,
TestFP8RecipeLayerNormLinearBase,
)
import logging
# read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
class GetRecipes:
@staticmethod
def none():
return None
@staticmethod
def fp8_blockwise():
# return default configs
return Float8BlockScaling()
# FP8 blockwise scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=False,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, and bgrad.
# If all four tensors cannot be loaded, leave the tensor dumps set to None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
assert recipe1 == GetRecipes.none, "Only None recipe is supported for recipe1"
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=False,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, and bgrad.
# If all four tensors cannot be loaded, leave the tensor dumps set to None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.9,
ln_out_error=0.5,
dgrad_error=1.5,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
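# Minimal usage sketch of the recipe these tests validate (illustrative only,
# assumes a CUDA device; not part of the test suite):
#
#   import torch
#   from transformer_engine.pytorch import Linear, fp8_autocast
#   from transformer_engine.common.recipe import Float8BlockScaling
#
#   linear = Linear(256, 128, bias=False).cuda()
#   x = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)
#   with fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
#       y = linear(x)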
@@ -35,6 +35,7 @@ TE_DType_To_Torch = {
     tex.DType.kByte: torch.uint8,
     tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
     tex.DType.kFloat8E5M2: torch.float8_e5m2,
+    tex.DType.kInt8: torch.int8,
     tex.DType.kInt32: torch.int32,
     tex.DType.kFloat32: torch.float32,
     tex.DType.kFloat16: torch.half,
@@ -10,11 +10,13 @@ import torch
 import transformer_engine_torch as tex
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
@@ -54,6 +56,67 @@ def general_gemm(
     # + "a valid `ub` communicator object."
     # )
+    if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
+        assert not gelu, "GELU not supported with int8 simulation"
+        assert gelu_in is None, "GELU input not supported with int8 simulation"
+        assert bias is None, "Bias not supported with int8 simulation"
+        assert not accumulate, "Accumulation not supported with int8 simulation"
+        assert ub is None, "User buffer not supported with int8 simulation"
+        assert ub_type is None, "User buffer type not supported with int8 simulation"
+        assert extra_output is None, "Extra output not supported with int8 simulation"
+        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
+        if layout == "TN":
+            # Forward GEMM: rowwise data and scales for both operands
+            qx_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._rowwise_data.view(dtype=torch.int8)
+            ref_scales_x = B._rowwise_scale_inv
+            ref_scales_w = A._rowwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
+                output_dtype=out_dtype
+            )
+            return y, None, None, None
+        elif layout == "NN":
+            # Dgrad GEMM: rowwise grad output, columnwise weight
+            qdout_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._columnwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
+                output_dtype=out_dtype
+            )
+            return y, None, None, None
+        elif layout == "NT":
+            # Wgrad GEMM: columnwise data and scales for both operands
+            qdout_data = B._columnwise_data.view(dtype=torch.int8)
+            qx_data = A._columnwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._columnwise_scale_inv
+            ref_scales_x = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul_wgrad(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+                output_dtype=out_dtype
+            )
+            return y, None, None, None
+        else:
+            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
         if ub_type == tex.CommOverlapType.RS:
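With NVTE_INT8_SIM_FP8=1, every blockwise-quantized GEMM above is rerouted to the Triton INT8 kernels. A minimal usage sketch of the switch (illustrative only; the flag is read at import time, so it must be set before transformer_engine is imported):
import os
os.environ["NVTE_INT8_SIM_FP8"] = "1"  # must precede the transformer_engine import
import torch
from transformer_engine.pytorch import Linear, fp8_autocast
from transformer_engine.common.recipe import Float8BlockScaling
linear = Linear(1024, 1024, bias=False).cuda()
x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)
with fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    y = linear(x)  # forward/backward GEMMs dispatch to w8a8_block_int8_matmul / _wgrad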
@@ -27,16 +27,18 @@ from .constants import dist_group_type
 from .utils import get_device_compute_capability
 from .jit import jit_fuser
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 __all__ = ["fp8_autocast", "fp8_model_init"]
+if IS_HIP_EXTENSION:
+    from transformer_engine.pytorch.utils import is_K100_AI, is_BW
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if IS_HIP_EXTENSION:
-        if get_device_compute_capability() == (9, 4):
-            return True, ""
+        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
+            return True, "DCU turn on fp8 simulation with int8"
         else:
             return False, "DCU not support fp8 for now"
     else:
@@ -61,7 +63,10 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
     if IS_HIP_EXTENSION:
-        return True, ""
+        if is_K100_AI() or is_BW():
+            return True, ""
+        else:
+            return False, "DCU not support block_scaling fp8 for now"
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Iterable
 import math
 import torch
 import transformer_engine_torch as tex
+import os
 from transformer_engine_torch import DType as TE_DType
 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -17,6 +17,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple
 aten = torch.ops.aten
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 class Float8BlockQuantizer(Quantizer):
     """Builder class for tensors quantized with current scaling using
@@ -44,7 +45,7 @@ class Float8BlockQuantizer(Quantizer):
         block_scaling_dim: int = 2,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.dtype = fp8_dtype
+        self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
         self.block_len = 128
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
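Under the simulation flag the quantizer above emits INT8 payloads while keeping the blockwise scale layout. For reference, symmetric per-block INT8 quantization reduces to scaling by the block amax (a sketch of the math only, not the quantizer's actual kernel):
import torch

def quantize_block_int8_sketch(block: torch.Tensor):
    """Quantize one block symmetrically: q = round(x / s) with s = amax / 127."""
    amax = block.abs().amax().clamp_min(1e-12)
    scale = amax / 127.0
    q = torch.round(block.float() / scale).clamp(-127, 127).to(torch.int8)
    return q, scale  # dequantize as q.float() * scale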
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper
import functools
import logging
logger = logging.getLogger(__name__)
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
tuning_full_space = False
# tuning_full_space = True
def get_full_tuning_space():
configs = []
if not tuning_full_space:
return configs
block_m_range = [16, 32, 64]
block_n_range = [16, 32, 64, 128]
block_k_range = [32, 64, 128]
num_warps_range = [4, 8]
group_m_range = [2, 4, 8]
# For now we see better perf with num_stages=0 for all gemm configs we care
# about, but keep this explicit so that we do not forget we may need to set
# it to other values in the future
num_stage_range = [0, 1, 2]
for block_m in block_m_range:
for block_n in block_n_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for num_stages in num_stage_range:
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=num_stages, num_warps=num_warps, enable_mmacfuse=2))
return configs
@triton.autotune(
configs= get_full_tuning_space() if tuning_full_space else [
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 2,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
],
key=['M', 'N', 'K'],
# reset_to_zero=['c_ptr']
)
@triton.jit
def _w8a8_block_int8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
# offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
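# Note on the scale math above: each K-block contributes
#     acc += (A_q @ B_q) * a_s[:, None] * b_s[None, :]
# so the FP32 accumulator already holds dequantized partial sums and the
# epilogue only needs a cast to the output dtype.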
@functools.lru_cache
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block int8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block int8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block INT8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
config_file_path,
)
return None
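# Sketch of the tuned-config file this function expects (hypothetical values),
# keyed by batch size M:
#
#   {
#       "16":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 32,  "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 1},
#       "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 1}
#   }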
def w8a8_block_int8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
best_config:Optional[dict]=None
) -> torch.Tensor:
"""matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
# print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
# assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[0]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
# assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
if best_config:
config=best_config
else:
#print("best config has not found!")
# config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
if M<=64:
config = {
"BLOCK_SIZE_M": 16, #64
"BLOCK_SIZE_N":block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<128:
config = {
"BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<=256:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
#print("config:",config)
# print(f"zhenggf, A.shape:{A.shape}, B.shape:{B.shape}")
# print(f"zhenggf, A.stride(-2):{A.stride(-2)}, A.stride(-1):{A.stride(-1)}, B.stride(1):{B.stride(1)}, B.stride(0):{B.stride(0)}")
# print(f"zhenggf, As.stride(-2):{As.stride(-2)}, As.stride(-1):{As.stride(-1)}, Bs.stride(1):{Bs.stride(1)}, Bs.stride(0):{Bs.stride(0)}")
# print(f"zhenggf, As.stride(-2):{As.stride(-2)}, As.stride(-1):{As.stride(-1)}, Bs.stride(1):{Bs.stride(1)}, Bs.stride(0):{Bs.stride(0)}")
# As = As.permute(1, 0).contiguous()
_w8a8_block_int8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
# As.stride(-2),
# As.stride(-1),
As.stride(1),
As.stride(0),
Bs.stride(1),
Bs.stride(0),
# **config,
)
config = _w8a8_block_int8_matmul.best_config
return C,config
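# Usage sketch (hypothetical tensors): A is (M, K) int8 with per-(1, 128)
# rowwise inverse scales transposed to shape (K // 128, M); B is (N, K) int8
# with (N // 128, K // 128) block inverse scales:
#
#   y, cfg = w8a8_block_int8_matmul(qa, qb, a_scales_t, b_scales, [128, 128],
#                                   output_dtype=torch.bfloat16)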
def apply_w8a8_block_int8_linear_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
bias: Optional[torch.Tensor] = None,
best_config:Optional[dict]=None):
q_input, x_scale,weight,weight_scale=_int8_gemm_helper(m=m,n=n,k=k,out_dtype=out_dtype,device=device,block_size=block_size)
print(f"zhenggf, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
x_scale = x_scale.permute(1, 0).contiguous()
output,config = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("triton accuracy check FAILED!!!")
else:
print("triton accuracy check passed")
# unit test end
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time() # start timing
g.replay()
torch.cuda.synchronize()
end_time_ = time.time() # stop timing
elapsed_time = round((end_time_ - start_time_) * 1000, 7) # total ms across the 1000 captured iterations, numerically equal to per-iteration microseconds
print("_time:{} us\n".format(elapsed_time))
quantiles = [0.5, 0.2, 0.8]
gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
def get_triton_cache(file_path,n,k,block_n,block_k):
# Returns the contents of the cache JSON file as a dict.
# Read the specified file first; if the path does not exist, create it with an empty cache.
cache_json_file=file_path
if os.path.exists(file_path):
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
cachedata = {}
# Write empty data to the new JSON file
with open(file_path, 'w') as file:
json.dump(cachedata, file)
# Parse all cache entries into key:config form: [M_N_K]: [config]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_value={
'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
'num_stages':int(sub_value['num_stages']),
'num_warps':int(sub_value['num_warps']),
# 'kpack':int(sub_value['kpack']),
'enable_mmacfuse':int(2),
}
configs_dict[configs_key]=configs_value
return configs_dict
def getspec_config(configs_dict,m,n,k,block_n,block_k):
if f"{m}_{n}_{k}_block[{block_n},{block_k}]" in configs_dict:
return configs_dict[f"{m}_{n}_{k}_block[{block_n},{block_k}]"]
else:
return None
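# Usage sketch (hypothetical cache file "tuned_configs.json"): look up a tuned
# config for one problem size, falling back to the heuristic defaults in
# w8a8_block_int8_matmul when no entry exists:
#
#   configs = get_triton_cache("tuned_configs.json", n=7168, k=1152, block_n=128, block_k=128)
#   cfg = getspec_config(configs, m=8192, n=7168, k=1152, block_n=128, block_k=128)  # None if absent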
# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def main():
m1=[item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
m2=[item<<2 if item <17 else (item - 8)<<3 for item in range(5, 29)]
m3=[2<<(item) for item in range(7, 13)]
m_list = m1+m2+m3
n_list=[576,2048,7168,256,7168,1536,1536,2304,7168]
k_list=[7168,512,1024,7168,128,7168,1536,7168,1152]
m_list = [8192]
n_list=[7168]
k_list=[1152]
block_size=[128, 128]
out_dtype=torch.bfloat16
_n=[]
_k=[]
_m=[]
_configs_block_m=[]
_configs_block_n=[]
_configs_block_k=[]
_configs_block_group_m=[]
_configs_block_num_warps=[]
_configs_block_num_stages=[]
_configs_kpack=[]
cost_times=[]
gpu_costtimes=[]
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
for i in range(0,len(k_list),1):
for m in m_list:
print("m:{} n:{} k:{} ".format(m,n_list[i],k_list[i]))
best_config = []
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# _configs_kpack.append(config['kpack'])
# Build a DataFrame containing the collected columns
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'linear gemm quant op time (us)': cost_times,'GPU kernel time (us)':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
'GROUP_SIZE_M':_configs_block_group_m,'num_warps':_configs_block_num_warps,'num_stages':_configs_block_num_stages,#'kpack':_configs_kpack
})
# Write the DataFrame to an Excel file
df.to_excel('gemmoutput.xlsx', index=False)
print("表格已保存到 gemmoutput.xlsx 文件中。")
if __name__ == "__main__":
main()
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper_b
import functools
import logging
logger = logging.getLogger(__name__)
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
tuning_full_space = False
# tuning_full_space = True
def get_full_tuning_space():
configs = []
if not tuning_full_space:
return configs
block_m_range = [16, 32, 64]
block_n_range = [16, 32, 64, 128]
block_k_range = [32, 64, 128]
num_warps_range = [4, 8]
group_m_range = [2, 4, 8]
# For now we see better perf with num_stages=0 for all gemm configs we care
# about, but keep this explicit so that we do not forget we may need to set
# it to other values in the future
num_stage_range = [0, 1, 2]
for block_m in block_m_range:
for block_n in block_n_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for num_stages in num_stage_range:
configs.append(triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m}, num_stages=num_stages, num_warps=num_warps, enable_mmacfuse=2))
return configs
@triton.autotune(
configs= get_full_tuning_space() if tuning_full_space else [
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 2,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
],
key=['M', 'N', 'K'],
# reset_to_zero=['c_ptr']
)
@triton.jit
def _w8a8_block_int8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
# offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
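# Unlike the forward/dgrad kernel in blockwise_int8_gemm_nt, this wgrad
# variant indexes Bs per output column (offs_bsn above is per-element rather
# than per 128-column block), matching B operands whose inverse scales are
# laid out per (1, 128) block along the reduced dimension.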
@functools.lru_cache
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block int8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block int8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block INT8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
config_file_path,
)
return None
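# Illustrative lookup sketch (hypothetical shapes): pick the config whose
# batch-size key is closest to the runtime M, as the docstring describes.
#
#   configs = get_w8a8_block_int8_configs(N=7168, K=1152, block_n=128, block_k=128)
#   if configs is not None:
#       config = configs[min(configs.keys(), key=lambda bs: abs(bs - M))]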
def w8a8_block_int8_matmul_wgrad(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
best_config: Optional[dict] = None
) -> torch.Tensor:
"""matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
# `As` arrives transposed here (k-groups first), unlike the standard layout.
assert triton.cdiv(A.shape[-1], block_k) == As.shape[0]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
# `Bs` is per-output-channel in this wgrad variant, so the per-block shape
# checks from the standard kernel do not apply.
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
# The offline config lookup (get_w8a8_block_int8_configs) is bypassed here;
# the kernel launch below appears to go through the Triton autotuner (its
# `best_config` is read back after the launch), so these heuristics only
# record a shape-dependent fallback.
if best_config:
config = best_config
else:
# Default config.
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1].
if M <= 64:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M < 128:
config = {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M <= 256:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
#print("config:",config)
# print(f"zhenggf, A.shape:{A.shape}, B.shape:{B.shape}")
# print(f"zhenggf, A.stride(-2):{A.stride(-2)}, A.stride(-1):{A.stride(-1)}, B.stride(1):{B.stride(1)}, B.stride(0):{B.stride(0)}")
# print(f"zhenggf, As.stride(-2):{As.stride(-2)}, As.stride(-1):{As.stride(-1)}, Bs.stride(1):{Bs.stride(1)}, Bs.stride(0):{Bs.stride(0)}")
# print(f"zhenggf, As.stride(-2):{As.stride(-2)}, As.stride(-1):{As.stride(-1)}, Bs.stride(1):{Bs.stride(1)}, Bs.stride(0):{Bs.stride(0)}")
# As = As.permute(1, 0).contiguous()
_w8a8_block_int8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(1),
As.stride(0),
Bs.stride(-2),
Bs.stride(-1),
# The heuristic `config` is intentionally not forwarded: the kernel is
# launched through its autotuner, and the winning config is read back below.
)
config = _w8a8_block_int8_matmul.best_config
return C, config
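# Minimal usage sketch (assumed shapes; mirrors the helper below). Note that
# both scale tensors must be transposed into k-groups-first layout first:
#
#   q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(
#       m=8192, n=7168, k=1152, out_dtype=torch.bfloat16)
#   x_scale = x_scale.permute(1, 0).contiguous()
#   weight_scale = weight_scale.permute(1, 0).contiguous()
#   out, cfg = w8a8_block_int8_matmul_wgrad(
#       q_input, weight, x_scale, weight_scale, [128, 128],
#       output_dtype=torch.bfloat16)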
def apply_w8a8_block_int8_linear_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int] = [128, 128],
bias: Optional[torch.Tensor] = None,
best_config: Optional[dict] = None):
q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
print(f"q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
# Transpose both scale tensors into the layout the Triton kernel expects.
x_scale = x_scale.permute(1, 0).contiguous()
weight_scale = weight_scale.permute(1, 0).contiguous()
print(f"after transposing for the triton kernel: q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
output,config = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("Triton accuracy check FAILED!")
else:
print("Triton accuracy check passed")
# unit test end
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time()  # start timing
g.replay()
torch.cuda.synchronize()
end_time_ = time.time()  # stop timing
# The graph replays 1000 captured iterations, so seconds * 1000 is the
# average time per iteration in microseconds.
elapsed_time = round((end_time_ - start_time_) * 1000, 7)
print("_time:{} us\n".format(elapsed_time))
gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
def get_triton_cache(file_path, n, k, block_n, block_k):
# Returns the saved JSON config file as a dict.
# Read the specified file; if the path does not exist, create it with
# empty contents.
cache_json_file = file_path
if os.path.exists(file_path):
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
cachedata = {}
# Write the empty data to a new JSON file.
with open(file_path, 'w') as file:
json.dump(cachedata, file)
# Parse every cache entry into key:config form: [M_N_K]:[config].
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_value = {
'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
'num_stages': int(sub_value['num_stages']),
'num_warps': int(sub_value['num_warps']),
'enable_mmacfuse': 2,
}
configs_dict[configs_key] = configs_value
return configs_dict
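# Assumed cache layout (inferred from the key construction above and the
# lookup in getspec_config below):
#
#   { "7168_1152_block[128,128]": { "8192": { "BLOCK_SIZE_M": 64, ... } } }
#
# i.e. top-level keys are "N_K_block[bn,bk]" strings, nested keys are M,
# flattened here into "M_N_K_block[bn,bk]" lookup keys.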
def getspec_config(configs_dict, m, n, k, block_n, block_k):
return configs_dict.get(f"{m}_{n}_{k}_block[{block_n},{block_k}]")
# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
# Bs is per-output-channel here, so only the k-tile count is checked.
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
Bs_tiles = [Bs[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs_tiles[i].t()[:,j*block_n:min((j + 1) * block_n, N)]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
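# The reference applies the same dequantization identity as the Triton
# kernel, tile by tile: each (k-tile i, n-tile j) product is scaled by the
# outer product of As[:, i] and the matching slice of Bs[:, i], so a
# mismatch against the kernel isolates a scaling or layout bug.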
def main():
m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]  # else branch never triggers for this range
m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
m3 = [2 << item for item in range(7, 13)]
m_list = m1 + m2 + m3
n_list = [576, 2048, 7168, 256, 7168, 1536, 1536, 2304, 7168]
k_list = [7168, 512, 1024, 7168, 128, 7168, 1536, 7168, 1152]
# Overridden below to benchmark a single shape.
m_list = [8192]
n_list = [7168]
k_list = [1152]
block_size=[128, 128]
out_dtype=torch.bfloat16
_n=[]
_k=[]
_m=[]
_configs_block_m=[]
_configs_block_n=[]
_configs_block_k=[]
_configs_block_group_m=[]
_configs_block_num_warps=[]
_configs_block_num_stages=[]
cost_times=[]
gpu_costtimes=[]
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
for i in range(len(k_list)):
for m in m_list:
print("m:{} n:{} k:{} ".format(m,n_list[i],k_list[i]))
best_config = []
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# Build a DataFrame from the collected lists.
df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k, 'linear-layer quantized GEMM time (us)': cost_times, 'GPU kernel time (us)': gpu_costtimes,
'BLOCK_SIZE_M': _configs_block_m, 'BLOCK_SIZE_N': _configs_block_n, 'BLOCK_SIZE_K': _configs_block_k,
'GROUP_SIZE_M': _configs_block_group_m, 'num_warps': _configs_block_num_warps, 'num_stages': _configs_block_num_stages,
})
# Write the DataFrame to an Excel file.
df.to_excel('gemmoutput.xlsx', index=False)
print("Results saved to gemmoutput.xlsx.")
if __name__ == "__main__":
main()
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
import logging
import math
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
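# Note: `to_int8` clamps before rounding, so out-of-range values saturate
# at -128/127 while in-range values are rounded to the nearest integer.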
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Avoid division by zero
eps,
# Information for int8
int8_min,
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / int8_max
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
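# Worked example of the quantization step above (illustrative numbers): for
# a group whose absmax is 4.2 with int8_max = 127, y_s = 4.2 / 127 ~= 0.0331,
# so an element y = 2.1 maps to y / y_s ~= 63.5, clamped into [-128, 127]
# before the int8 cast; dequantization is simply y_q * y_s.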
def per_token_group_quant_int8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dtype of the output tensor. Note that only `torch.int8`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)  # N is block_size[1]
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
int8_min=int8_min,
int8_max=int8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s, BLOCK, num_warps, num_stages, M
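# Illustrative usage (hypothetical shapes): quantizing a (4, 256) activation
# in groups of 128 along the last dim yields an int8 tensor of the same
# shape plus one float32 scale per group:
#
#   x = torch.randn(4, 256, device="cuda")
#   x_q, x_s, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(x, 128)
#   # x_q: int8 (4, 256); x_s: float32 (4, 2)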
def _int8_gemm_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
# Builds quantized test inputs: per-token-group activation quantization
# and per-block weight scales of shape (ceil(n/block_n), ceil(k/block_k)).
input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
weight = to_int8(torch.randn((n ,k), device=device) * 5)
weight_scale = (torch.randn((math.ceil(n/block_size[0]), math.ceil(k/block_size[1])), device=device,
dtype=torch.float32))
print("input.dtype:",input.dtype)
#print("m:{} n:{} k:{},weight_scale.shape:{}".format(m,n,k,weight_scale.shape))
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale,_,_,_,_ = per_token_group_quant_int8(input_2d, block_size[1])
return q_input, x_scale,weight,weight_scale
def _int8_gemm_helper_b(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
# Builds quantized test inputs: per-token-group activation quantization
# and per-output-channel weight scales of shape (n, ceil(k/block_k)).
input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
weight = to_int8(torch.randn((n ,k), device=device) * 5)
weight_scale = (torch.randn((n, math.ceil(k/block_size[1])), device=device,
dtype=torch.float32))
print("input.dtype:",input.dtype)
#print("m:{} n:{} k:{},weight_scale.shape:{}".format(m,n,k,weight_scale.shape))
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale,_,_,_,_ = per_token_group_quant_int8(input_2d, block_size[1])
return q_input, x_scale,weight,weight_scale
def _int8_gemm_helper_test(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
# Benchmarks the per-token-group quantization kernel on a random input.
input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
weight = (torch.randn((n ,k), device=device) * 5).t().to(dtype=out_dtype)
print("input.dtype:",input.dtype)
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale,BLOCK,num_warps,num_stages,M= per_token_group_quant_int8(input_2d, block_size[1])
start_time_ = time.time()  # start timing
for it in range(1000):
q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
torch.cuda.synchronize()
end_time_ = time.time()  # stop timing
# Total seconds * 1000 gives ms over 1000 calls, i.e. the average time
# per call in microseconds.
elapsed_time = round((end_time_ - start_time_) * 1000, 7)
print("_time:{} us\n".format(elapsed_time))
return q_input, x_scale,elapsed_time,BLOCK,num_warps,num_stages,M
def main():
m_list=[1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768]
n_list=[576,2048,7168,256,7168,1536,1536]
k_list=[7168,512,1024,7168,128,7168,1536]
block_size=[128,128]
out_dtype=torch.bfloat16
_n=[]
_k=[]
_m=[]
config_blocks=[]
config_num_warps=[]
config_num_stages=[]
config_M=[]
cost_times=[]
for i in range(len(k_list)):
for m in m_list:
print("m:{} n:{} k:{} ".format(m,n_list[i],k_list[i]))
q_input, x_scale,elapsed_time,BLOCK,num_warps,num_stages,M=_int8_gemm_helper_test(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=torch.bfloat16)
cost_times.append(elapsed_time)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
config_blocks.append(BLOCK)
config_num_warps.append(num_warps)
config_num_stages.append(num_stages)
config_M.append(M)
# Build a DataFrame from the collected lists.
df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k, 'quantization kernel time (us)': cost_times, 'BLOCK': config_blocks, 'num_warps': config_num_warps, 'config_num_stages': config_num_stages, 'config_M': config_M})
# Write the DataFrame to an Excel file.
df.to_excel('output.xlsx', index=False)
print("Results saved to output.xlsx.")
if __name__ == "__main__":
main()