Commit 1b303e91 authored by yuguo
parents 52ba87a1 735227cd
@@ -36,6 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
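# NVTE_INT8_SIM_FP8=1 makes the blockwise FP8 recipe emit int8 payloads and
# route GEMMs through the Triton int8 kernels instead of native FP8.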
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_int8_blockwise_layers.xml $TE_PATH/tests/pytorch/test_int8_blockwise_layers.py || test_fail "test_int8_blockwise_layers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
......
@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
)
# recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
)
# recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
......
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import math
import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
QuantizeResult,
)
from test_float8_current_scaling_exact import (
TestFP8RecipeLinearBase,
TestFP8RecipeLayerNormLinearBase,
)
import logging
# Read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory.
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
class GetRecipes:
@staticmethod
def none():
return None
@staticmethod
def fp8_blockwise():
# return default configs
return Float8BlockScaling()
# FP8 blockwise scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
    @classmethod
    def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=False,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump directory; if it exists, read the dumped files
        # to get y, dgrad, wgrad, and bgrad. If all four tensors cannot be
        # loaded, the tensor dumps stay None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
assert recipe1 == GetRecipes.none, "Only None recipe is supported for recipe1"
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
    @classmethod
    def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=False,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump directory; if it exists, read the dumped files
        # to get y, dgrad, wgrad, and bgrad. If all four tensors cannot be
        # loaded, the tensor dumps stay None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.9,
ln_out_error=0.5,
dgrad_error=1.5,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@@ -35,6 +35,7 @@ TE_DType_To_Torch = {
tex.DType.kByte: torch.uint8,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
tex.DType.kFloat8E5M2: torch.float8_e5m2,
tex.DType.kInt8: torch.int8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
......
@@ -10,11 +10,13 @@ import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
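# NVTE_INT8_SIM_FP8=1 enables the int8-simulation path below: blockwise "FP8"
# tensors carry int8 payloads and GEMMs dispatch to the Triton w8a8 kernels.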
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [
"general_gemm",
"general_grouped_gemm",
@@ -54,6 +56,67 @@ def general_gemm(
# + "a valid `ub` communicator object."
# )
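    # Int8-simulation dispatch: the layout string maps onto the three GEMMs of a
    # linear layer. "TN" is fprop (x @ w^T), "NN" is dgrad (dy @ w), and "NT" is
    # wgrad (dy^T @ x); scales come from the matching rowwise/columnwise block
    # scale inverses.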
    if int8_simulation_fp8 and (
        isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)
    ):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert not accumulate, "Accumulation not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
if layout == "TN":
qx_data = (
B._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
A._rowwise_data.view(dtype=torch.int8)
)
ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
return y, None, None, None
elif layout == "NN":
qdout_data = (
B._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
A._columnwise_data.view(dtype=torch.int8)
)
ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
return y, None, None, None
elif layout == "NT":
qdout_data = (
B._columnwise_data.view(dtype=torch.int8)
)
qx_data = (
A._columnwise_data.view(dtype=torch.int8)
)
ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype
)
return y, None, None, None
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if ub is not None:
assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
if ub_type == tex.CommOverlapType.RS:
......
......@@ -27,16 +27,18 @@ from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
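# With NVTE_INT8_SIM_FP8=1, DCU devices without native FP8 can report FP8
# support by simulating FP8 block scaling with int8 (see check_fp8_support).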
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = ["fp8_autocast", "fp8_model_init"]
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import is_K100_AI, is_BW
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if IS_HIP_EXTENSION:
if get_device_compute_capability() == (9, 4):
return True, ""
        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
            return True, "DCU enables fp8 simulation with int8"
        else:
            return False, "DCU does not support fp8 for now"
else:
@@ -61,7 +63,10 @@ def check_mxfp8_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
return True, ""
if is_K100_AI() or is_BW():
return True, ""
else:
return False, "DCU not support block_scaling fp8 for now"
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
......
@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Iterable
import math
import torch
import transformer_engine_torch as tex
import os
from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -17,6 +17,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten
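# With NVTE_INT8_SIM_FP8=1, Float8BlockQuantizer emits kInt8 data in place of
# the requested FP8 dtype so the Triton int8 GEMMs can consume it.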
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
class Float8BlockQuantizer(Quantizer):
"""Builder class for tensors quantized with current scaling using
@@ -44,7 +45,7 @@ class Float8BlockQuantizer(Quantizer):
block_scaling_dim: int = 2,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.dtype = fp8_dtype
+        self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
self.block_len = 128
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
......
This diff is collapsed.
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
import logging
import math
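# Saturating int8 cast: clamp to the signed int8 range, round, then convert.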
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
    # Columns of the input
N,
    # Epsilon to avoid division by zero
eps,
# Information for int8
int8_min,
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / int8_max
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_int8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value used to avoid division by zero.
        dtype: The dtype of the output tensor. Note that only `torch.int8`
            is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
    BLOCK = triton.next_power_of_2(N)  # N is block_size[1], i.e. the group size
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
int8_min=int8_min,
int8_max=int8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
    return x_q, x_s, BLOCK, num_warps, num_stages, M
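
# Usage sketch (an illustrative helper, not part of the original script):
# quantize a bf16 activation in groups of 128 and measure the dequantization
# error of the round trip.
def _example_quant_roundtrip():
    x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    x_q, x_s, *_ = per_token_group_quant_int8(x, group_size=128)
    # x_q is int8 with x's shape; x_s holds one float32 scale per 128-wide group.
    x_deq = x_q.view(16, -1, 128).float() * x_s.unsqueeze(-1)
    max_err = (x_deq.view_as(x) - x.float()).abs().max()
    print(f"max dequantization error: {max_err:.5f}")  # at most ~one step, x_s.max()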
def _int8_gemm_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
    # Build quantized inputs for a blockwise int8 GEMM: per-token-group
    # activation quantization plus a per-128x128-block weight scale grid.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = torch.randn(
        (math.ceil(n / block_size[0]), math.ceil(k / block_size[1])),
        device=device,
        dtype=torch.float32,
    )
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{}, weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale
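
# Note: _int8_gemm_helper_b below differs from _int8_gemm_helper only in the
# weight scale layout: one float32 scale per output row per 128-wide k-block
# (shape (n, ceil(k / 128))) instead of a 128x128 block grid.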
def _int8_gemm_helper_b(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
    # Build quantized inputs with per-row weight scales (one float32 scale per
    # output channel per 128-wide k-block).
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = torch.randn(
        (n, math.ceil(k / block_size[1])),
        device=device,
        dtype=torch.float32,
    )
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{}, weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale
def _int8_gemm_helper_test(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int]=[128,128],
best_config:Optional[list] = None):
    # Benchmark the per-token-group int8 quantization kernel.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = (torch.randn((n, k), device=device) * 5).t().to(dtype=out_dtype)
    print("input.dtype:", input.dtype)
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(input_2d, block_size[1])
    start_time_ = time.time()  # start timing
    for it in range(1000):
        q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    torch.cuda.synchronize()
    end_time_ = time.time()  # stop timing
    # Total milliseconds over 1000 iterations, i.e. the average per-iteration
    # time in microseconds.
    elapsed_time = round((end_time_ - start_time_) * 1000, 7)
    print("quantization time: {} us\n".format(elapsed_time))
    return q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M
def main():
    m_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    n_list = [576, 2048, 7168, 256, 7168, 1536, 1536]
    k_list = [7168, 512, 1024, 7168, 128, 7168, 1536]
    block_size = [128, 128]
    out_dtype = torch.bfloat16
    _n = []
    _k = []
    _m = []
    config_blocks = []
    config_num_warps = []
    config_num_stages = []
    config_M = []
    cost_times = []
    for i in range(0, len(k_list), 1):
        for m in m_list:
            print("m:{} n:{} k:{}".format(m, n_list[i], k_list[i]))
            q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M = _int8_gemm_helper_test(
                m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype
            )
            cost_times.append(elapsed_time)
            _n.append(n_list[i])
            _k.append(k_list[i])
            _m.append(m)
            config_blocks.append(BLOCK)
            config_num_warps.append(num_warps)
            config_num_stages.append(num_stages)
            config_M.append(M)
    # Build a DataFrame from the collected results.
    df = pd.DataFrame({
        'm': _m,
        'n': _n,
        'k': _k,
        'quant_kernel_time_us': cost_times,
        'BLOCK': config_blocks,
        'num_warps': config_num_warps,
        'config_num_stages': config_num_stages,
        'config_M': config_M,
    })
    # Write the DataFrame to an Excel file.
    df.to_excel('output.xlsx', index=False)
    print("Results saved to output.xlsx.")
if __name__ == "__main__":
main()