Commit a5892578 authored by yuguo
parents f9faa7ca 793e0103
......@@ -34,6 +34,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tes
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
# channelwise int8 simulation tests
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py (int8 simulation)"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py || test_fail "test_int8_blockwise_gemm_exact.py"
......
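(For context: the NVTE_INT8_SIM_FP8=1 switch used above is read once at import time by the GEMM extension changed further down in this diff; a minimal sketch of the gate:)

```python
import os

# When set to 1, FP8 GEMMs are rerouted through per-token int8 quantization
# plus an integer GEMM (see the GEMM extension hunks further down).
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
```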
......@@ -14,6 +14,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
......@@ -715,7 +716,7 @@ class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
hidden_size,
out_size,
dtype,
use_bias=True,
use_bias=not int8_simulation_fp8,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
......@@ -775,7 +776,7 @@ class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBas
hidden_size,
out_size,
dtype,
use_bias=True,
use_bias=not int8_simulation_fp8,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
......
......@@ -557,4 +557,4 @@ te_dw = tex.generic_batchgemm(
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw, batched_int32_dw, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
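(Aside, not part of the patch: why tightening the tolerance to atol=0, rtol=0 is sound. Assuming the kernel accumulates in a wide integer type, as the batched_int32_dw reference name suggests, int8 operands incur no rounding, so a correct result must match an exact reference bit-for-bit. A minimal torch-only illustration with made-up shapes:)

```python
import torch

# int8 x int8 with wide-integer accumulation is exact, so two exact references
# (integer matmul, and float64 which represents these small integers losslessly)
# must agree bit-for-bit -- hence atol=0, rtol=0 in the test above.
a = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 16), dtype=torch.int8)
ref_int = torch.matmul(a.to(torch.int64), b.to(torch.int64)).to(torch.int32)
ref_f64 = torch.matmul(a.to(torch.float64), b.to(torch.float64)).to(torch.int32)
torch.testing.assert_close(ref_int, ref_f64, atol=0, rtol=0)
```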
......@@ -42,7 +42,7 @@ extern "C" {
void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
*
......@@ -77,7 +77,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
bool transb, bool grad, NVTETensor workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const NVTETensor counter,
cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = -1);
cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute multiple pairs of matrix multiplication, potentially fused with other operations,
* on multiple streams.
......
......@@ -638,6 +638,9 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}
int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
return true;
#else
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
......@@ -645,4 +648,5 @@ int nvte_is_non_tn_fp8_gemm_supported() {
// (remove the note once it's done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
#endif
}
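(For reference, the Python-level gate shown at the end of this diff applies the same ROCm short-circuit before the compute-capability check; a minimal sketch of that logic, with `is_hip` standing in for IS_HIP_EXTENSION:)

```python
import torch

def is_non_tn_fp8_gemm_supported(is_hip: bool = False) -> bool:
    """Mirror of the C++ gate above: ROCm always supports non-TN FP8 GEMMs,
    CUDA only on sm 10.x-11.x or sm 13.0+ (per the checks in this diff)."""
    if is_hip:
        return True
    major, minor = torch.cuda.get_device_capability()
    return (10, 0) <= (major, minor) < (12, 0) or (major, minor) >= (13, 0)
```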
......@@ -9,14 +9,24 @@ import os
import torch
import transformer_engine_torch as tex
import w8a8_matmul_extension
from ..constants import TE_DType
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
per_token_quant_fp8_to_int8_v2,
per_token_quant_fp8_to_int8_opt,
channelwise_dequantize,
channelwise_dequantize_transA,
channelwise_dequantize_transA_float,
channelwise_dequantize_transB,
channelwise_dequantize_transA_add,
channelwise_dequantize_transA_float_add)
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [
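(Aside, not part of the patch: the per_token_quant_fp8_to_int8* helpers imported above are Triton kernels whose exact behavior is not shown in this diff. A plausible plain-torch reading, offered only as a hypothetical reference for reviewing the hunks below:)

```python
import torch

def per_token_fp8_to_int8_reference(x_fp8: torch.Tensor, scale_inv: torch.Tensor):
    """Hypothetical reference: dequantize an FP8 payload with its tensor-wide
    scale_inv, then requantize each row (token) to int8 with a per-token scale."""
    x = x_fp8.float() * scale_inv                 # back to the original value range
    amax = x.abs().amax(dim=-1, keepdim=True)     # per-token absolute maximum
    scales = (amax / 127.0).clamp(min=1e-12)      # per-token int8 step size
    x_int8 = torch.clamp(torch.round(x / scales), -127, 127).to(torch.int8)
    return x_int8, scales                         # scales are re-applied after the GEMM
```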
......@@ -165,6 +175,106 @@ def general_gemm(
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
assert out_dtype is torch.bfloat16
x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
None,
quantization_params,
TE_DType[torch.int32],
bias,
bias_dtype,
gelu,
gelu_in,
grad, # grad
workspace,
workspace.shape[0],
False,
use_split_accumulator,
)[0]
y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
return y, None, None, None
elif layout == "NN":
assert out_dtype is torch.bfloat16
dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dx_int32 = tex.generic_gemm(
w_int8,
transa,
dy_int8,
transb,
None,
quantization_params,
TE_DType[torch.int32],
bias,
bias_dtype,
gelu,
gelu_in,
grad, # grad
workspace,
workspace.shape[0],
False,
use_split_accumulator,
)[0]
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
return dx, None, None, None
elif layout == "NT":
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dw_int32 = tex.generic_gemm(
x_int8,
transa,
dy_int8,
transb,
None,
quantization_params,
TE_DType[torch.int32],
bias,
bias_dtype,
gelu,
gelu_in,
grad, # grad
workspace,
workspace.shape[0],
False,
use_split_accumulator,
)[0]
if out_dtype is torch.bfloat16:
if accumulate:
out = channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
else:
if accumulate:
out = channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
return out, None, None, None
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
args = (
A,
transa, # transa
......@@ -311,6 +421,165 @@ def general_grouped_gemm(
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
assert len(set(m_splits)) == 1, "Int8 simulation grouped GEMM currently only supports equal m_splits (same token padding as batched GEMM)."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
if layout == "TN":
assert out_dtype is torch.bfloat16
qx_data_list = []
w_data_list = []
scales_x_list = []
scales_w_list = []
for b in B:
x_int8, x_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qx_data_list.append(x_int8)
scales_x_list.append(x_scales)
for a in A:
w_int8, w_scales = per_token_quant_fp8_to_int8(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
w_data_list.append(w_int8)
scales_w_list.append(w_scales)
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
qx_data = torch.stack(qx_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
y_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
y_int32 = tex.generic_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qx_data.view(-1, qx_data.size(-1)),
transb,
y_int32.view(-1, y_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize_transB(scales_x_list[i], scales_w_list[i], y_int32[i])
out[0] = out[0].view(-1, out[0].size(-1))
return out, bias, gelu_input
elif layout == "NN":
assert out_dtype is torch.bfloat16
qdout_data_list = []
w_data_list = []
scales_dout_list = []
scales_w_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
w_data_list.append(w_int8)
scales_w_list.append(w_scales)
num_gemms = len(A)
seq_len = sum(m_splits) // num_gemms
qdout_data = torch.stack(qdout_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
dx_int32 = torch.empty((num_gemms, seq_len, out[0].size(-1)), dtype=torch.int32, device="cuda")
dx_int32 = tex.generic_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dx_int32.view(-1, dx_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
out[0] = out[0].view(num_gemms, seq_len, out[0].size(-1))
for i in range(num_gemms):
out[0][i] = channelwise_dequantize(scales_dout_list[i], scales_w_list[i], dx_int32[i])
return out, bias, gelu_input
elif layout == "NT":
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
scales_x_list = []
for b in B:
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]), b._scale_inv, False)
qdout_data_list.append(dy_int8)
scales_dout_list.append(dy_scales)
for a in A:
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]), a._scale_inv, False)
qx_data_list.append(x_int8)
scales_x_list.append(x_scales)
num_gemms = len(A)
qdout_data = torch.stack(qdout_data_list).contiguous()
qx_data = torch.stack(qx_data_list).contiguous()
dw_int32 = torch.empty((num_gemms, qdout_data.size(-1), qx_data.size(-1)), dtype=torch.int32, device="cuda")
dw_int32 = tex.generic_batchgemm(
qx_data.view(-1, qx_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
dw_int32.view(-1, dw_int32.size(-1)),
num_gemms,
None,
TE_DType[torch.int32],
bias[0],
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
False,
use_split_accumulator,
)[0]
if out_dtype is torch.bfloat16:
if accumulate:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA(scales_dout_list[i], scales_x_list[i], dw_int32[i])
else:
if accumulate:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA_float_add(scales_dout_list[i], scales_x_list[i], dw_int32[i], out[i])
else:
for i in range(num_gemms):
out[i] = channelwise_dequantize_transA_float(scales_dout_list[i], scales_x_list[i], dw_int32[i])
return out, bias, gelu_input
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
bias = tex.te_general_grouped_gemm(
A,
transa,
......
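(Aside, not part of the patch: an end-to-end, torch-only sketch of the simulation these two code paths implement: quantize both operands per token to int8, run an exact integer GEMM, then rescale channelwise, which should track the straight float matmul closely. The grouped path above additionally relies on all m_splits being equal so the per-group operands can be stacked and dispatched as one batched GEMM. Helper names below are illustrative, not the library's.)

```python
import torch

def per_token_int8(x: torch.Tensor):
    """Quantize each row to int8 with its own scale (illustrative helper)."""
    scales = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-12)
    return torch.clamp(torch.round(x / scales), -127, 127).to(torch.int8), scales

x = torch.randn(64, 128)           # activations [tokens, hidden]
w = torch.randn(32, 128)           # weights     [out_features, hidden]
x_i8, x_s = per_token_int8(x)
w_i8, w_s = per_token_int8(w)
acc = torch.matmul(x_i8.to(torch.int64), w_i8.t().to(torch.int64))   # exact integer GEMM
y = x_s * w_s.T * acc.to(torch.float32)                              # channelwise dequantize
ref = x @ w.t()
rel_err = torch.linalg.norm(y - ref) / torch.linalg.norm(ref)
assert rel_err < 0.05, f"int8 simulation drifted further than expected: {rel_err:.4f}"
```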
......@@ -328,6 +328,16 @@ def channelwise_dequantize_transA_float(A, B, C):
out_scales = A.T * B
return out_scales * C.to(dtype=torch.float32)
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_add(A, B, C, D):
out_scales = A.T * B
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16) + D
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_float_add(A, B, C, D):
out_scales = A.T * B
return out_scales * C.to(dtype=torch.float32) + D
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transB(A, B, C):
out_scales = A * B.T
......
......@@ -475,6 +475,8 @@ def is_non_tn_fp8_gemm_supported() -> bool:
"""Checks whether the device supports
non-TN layouts for FP8 GEMMs.
"""
if IS_HIP_EXTENSION:
return True
device_capability = torch.cuda.get_device_capability()
return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
......