Commit 3eb6ea62 authored by yuguo

[DCU] add NVTE_INT8_SIM_FP8_TENSORWISE

parent 68d6c506
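
For context, a minimal self-contained sketch (not part of this commit) of what the tensorwise INT8 simulation computes: each tensor is quantized with a single scale derived from its amax, as in FP8 current scaling but targeting the INT8 max norm of 127, and the INT32 GEMM result is dequantized with the product of the two scale_inv values. The helper name tensorwise_quant_int8 is hypothetical, for illustration only.

import torch

def tensorwise_quant_int8(x: torch.Tensor):
    # One scale per tensor: amax maps to the INT8 max norm (127).
    amax = x.abs().amax().to(torch.float32)
    scale = 127.0 / torch.clamp(amax, min=1e-12)
    x_int8 = torch.round(x.to(torch.float32) * scale).clamp_(-127, 127).to(torch.int8)
    return x_int8, 1.0 / scale  # (data, scale_inv)

x = torch.randn(64, 128, dtype=torch.bfloat16)
w = torch.randn(32, 128, dtype=torch.bfloat16)
x_i8, sx = tensorwise_quant_int8(x)
w_i8, sw = tensorwise_quant_int8(w)
y_i32 = torch.matmul(x_i8.to(torch.int32), w_i8.to(torch.int32).t())  # integer GEMM
y = (y_i32.to(torch.float32) * sx * sw).to(torch.bfloat16)            # tensorwise dequantize
print((y.to(torch.float32) - x.to(torch.float32) @ w.to(torch.float32).t()).abs().max())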
@@ -36,6 +36,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
# channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
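# tensorwise int8 test (same suite, with tensorwise simulation enabled)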
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
......
# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union, Optional
@@ -20,10 +21,16 @@ import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
from transformer_engine.pytorch.triton.per_token_group_quant import per_token_quant_int8, per_token_quant_int8_v2, per_token_quant_fp8_to_int8, per_token_quant_fp8_to_int8_v2, channelwise_dequantize, channelwise_dequantize_transA, channelwise_dequantize_transB, per_token_quant_fp8_to_int8_opt
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
per_token_quant_fp8_to_int8_opt,
channelwise_dequantize,
channelwise_dequantize_transA,
channelwise_dequantize_transB,
tensorwise_dequantize)
import time
import os
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
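# With the flag set, to_float8_CS below already produces INT8 data with a single
# per-tensor scale (the quantizer's dtype is switched to kInt8), so its
# _data/_scale_inv are used directly instead of re-quantizing per token/channel.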
# TN
m = 4096
@@ -227,6 +234,7 @@ transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
@@ -248,8 +256,12 @@ w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
@@ -278,40 +290,43 @@ y_int32 = tex.generic_gemm(
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
torch.cuda.synchronize()
end = time.time()
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# torch.cuda.synchronize()
# end = time.time()
# NN
@@ -321,6 +336,7 @@ transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
@@ -336,8 +352,12 @@ dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
@@ -372,7 +392,10 @@ dx_int32 = tex.generic_gemm(
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# print("dx_scales.shape: ", dx_scales.shape)
@@ -380,38 +403,38 @@ dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
dx_int32 = tex.generic_gemm(
w_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
torch.cuda.synchronize()
end = time.time()
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# torch.cuda.synchronize()
# end = time.time()
# NT
@@ -423,6 +446,7 @@ transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -442,10 +466,14 @@ end = time.time()
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
else:
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
@@ -471,47 +499,50 @@ dw_int32 = tex.generic_gemm(
use_split_accumulator,
)[0]
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
else:
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dw_int32 = tex.generic_gemm(
x_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
torch.cuda.synchronize()
end = time.time()
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# torch.cuda.synchronize()
# end = time.time()
......
@@ -716,6 +716,12 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
template <typename T>
struct is_int8 : std::false_type {};
template <>
struct is_int8<int8> : std::true_type {};
template <typename T>
struct is_fp4 : std::false_type {};
......
@@ -1479,8 +1479,8 @@ private:
};
// Define a static userArgs manager
// static userArgsManager UAManager;
// static d_userArgsManager d_UAManager;
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
// int device_id;
// hipGetDevice(&device_id);
// hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
int device_id;
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
}
// Get the default values from the groupedgemm object
// groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
// NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
// userArgs,
// m.size() * sizeof(hipblaslt_ext::UserArguments),
// hipMemcpyHostToDevice, stream));
NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
userArgs,
m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice, stream));
// Make sure to initialize every time the algo changes
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
......
@@ -291,8 +291,8 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(is_fp8_dtype(output.data.dtype),
"Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
NVTE_CHECK(is_fp8_dtype(output.data.dtype) || is_int8_dtype(output.data.dtype),
"Tensor must be FP8 or INT8, but got dtype=", to_string(output.data.dtype));
NVTE_CHECK(output.amax.numel() == 1,
"Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
")");
@@ -314,7 +314,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
// Maximum FP8 value
float max_fp8 = 0.f;
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;);
// Update scale
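
A sketch of the scale rule this hunk extends, assuming Quantized_Limits<int8>::max_norm resolves to 127 (the FP8 values being 448 for E4M3 and 57344 for E5M2):

def compute_scale(amax: float, max_norm: float) -> float:
    # max_norm: 448.0 (E4M3), 57344.0 (E5M2), or 127.0 (INT8, assumed)
    return max_norm / amax if amax > 0.0 else 1.0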
......
@@ -166,7 +166,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
if (is_fp8_dtype(type) || is_int8_dtype(type)) {
// FP8 (or INT8) output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
......
@@ -1076,7 +1076,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
@@ -1105,7 +1105,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp
const size_t N = product(input->data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input->data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
output->data.dtype, OType,
if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
constexpr int nvec = 32 / sizeof(IType);
@@ -1275,6 +1275,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
NVTE_CHECK(false,
"NVTE_INT8_SIM_FP8_TENSORWISE does not support columnwise (transposed) data!");
}
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
......
@@ -231,12 +231,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // __HIP_PLATFORM_AMD__
static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype), "Input must have FP8 or INT8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
......
@@ -183,7 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -196,18 +196,21 @@ __launch_bounds__(unary_kernel_threads) __global__
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
ComputeType temp = OP(val, p);
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
temp = temp * s;
}
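// INT8 cannot rely on a saturating cast the way FP8 does, so round to nearest and clamp to the symmetric range [-127, 127]: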
if constexpr (is_int8<OutputType>::value) {
storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
} else {
storer.separate()[i] = static_cast<OutputType>(temp);
}
}
storer.store(tid, N);
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
@@ -236,7 +239,7 @@ __launch_bounds__(unary_kernel_threads) __global__
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -251,18 +254,21 @@ __launch_bounds__(unary_kernel_threads) __global__
const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
ComputeType temp = OP(val, p) * g;
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
temp = temp * s;
}
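// As in the unary kernel above: round to nearest and clamp INT8 outputs to [-127, 127].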
if constexpr (is_int8<OutputType>::value) {
storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
} else {
storer.separate()[i] = static_cast<OutputType>(temp);
}
}
storer.store(tid, N);
}
if constexpr (is_fp8<OutputType>::value) {
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
......
@@ -30,10 +30,14 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
channelwise_dequantize_transA_float,
channelwise_dequantize_transB,
channelwise_dequantize_transA_add,
channelwise_dequantize_transA_float_add)
channelwise_dequantize_transA_float_add,
tensorwise_dequantize,
tensorwise_dequantize_add,
tensorwise_dequantize_float,
tensorwise_dequantize_float_add)
from transformer_engine.pytorch.utils import get_device_compute_capability
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise
__all__ = [
"general_gemm",
"general_grouped_gemm",
@@ -191,6 +195,10 @@ def general_gemm(
if layout == "TN":
assert out_dtype is torch.bfloat16
out_shape = B._data.shape[:-1] + (A._data.shape[0], )
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -212,12 +220,20 @@ def general_gemm(
False,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
y = torch.empty_like(y_int32, device=y_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(x_scales, w_scales, y_int32, y)
else:
y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
return y.view(out_shape), None, None, None
elif layout == "NN":
assert out_dtype is torch.bfloat16
dx_shape = B._data.shape[:-1] + (A._data.shape[-1], )
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -239,11 +255,19 @@ def general_gemm(
False,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
dx = torch.empty_like(dx_int32, device=dx_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
return dx.view(dx_shape), None, None, None
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
x_int8, x_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
@@ -267,12 +291,27 @@ def general_gemm(
)[0]
if out_dtype is torch.bfloat16:
if accumulate:
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize_add(dy_scales, x_scales, dw_int32, out)
else:
channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
else:
if int8_simulation_fp8_tensorwise:
out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
else:
if accumulate:
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize_float_add(dy_scales, x_scales, dw_int32, out)
else:
channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
else:
if int8_simulation_fp8_tensorwise:
out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.float32)
tensorwise_dequantize_float(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
return out, None, None, None
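# Layout mapping in this wrapper, as the variable names suggest: "TN" is the
# forward GEMM (y = x @ w.T), "NN" is dgrad (dx = dy @ w), and "NT" is wgrad
# (dw, accumulated into `out` when accumulate is set).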
......
@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = ["fp8_autocast", "fp8_model_init"]
......
@@ -20,7 +20,7 @@ from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
aten = torch.ops.aten
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
from transformer_engine.pytorch.fp8 import int8_simulation_fp8
class Float8BlockQuantizer(Quantizer):
"""Builder class for tensors quantized with current scaling using
......
@@ -16,6 +16,7 @@ from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type
from transformer_engine.pytorch.fp8 import int8_simulation_fp8_tensorwise
aten = torch.ops.aten
@@ -217,7 +218,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.scale = torch.ones(1, dtype=torch.float32, device=device)
self.amax = torch.empty(1, dtype=torch.float32, device=device)
self.dtype = fp8_dtype
self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
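# Under tensorwise INT8 simulation the quantizer emits kInt8 data while reusing
# the FP8 current-scaling machinery (amax reduction, scale update) unchanged.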
self.with_amax_reduction = with_amax_reduction
self.amax_reduction_group = amax_reduction_group
self.force_pow_2_scales = force_pow_2_scales
......
@@ -284,6 +284,8 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
max_fp8 = 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
elif fp8_dtype == tex.DType.kInt8:
max_fp8 = 127.0
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
multi_tensor_applier(
......
@@ -122,6 +122,68 @@ def per_token_quant_int8_v2(x):
@triton.jit
def _tensorwise_dequantize_impl(
x_ptr,
y_ptr,
scaleA_ptr,
scaleB_ptr,
stride_x,
stride_y,
N,
is_add: tl.constexpr,
is_float: tl.constexpr,
BLOCK: tl.constexpr,
):
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
a_scale = tl.load(scaleA_ptr)
b_scale = tl.load(scaleB_ptr)
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
result = x * a_scale * b_scale
if is_add:
y = tl.load(y_ptr + row_id * stride_y + cols, mask=mask, other=0.0).to(tl.float32)
result += y
if is_float:
tl.store(y_ptr + row_id * stride_y + cols, result, mask=mask)
else:
tl.store(y_ptr + row_id * stride_y + cols, result.to(tl.bfloat16), mask=mask)
def _tensorwise_dequantize(a_scale, b_scale, x, y, is_add=False, is_float=False):
assert x.is_contiguous()
x = x.view(-1, x.shape[-1])
M = x.numel() // x.shape[-1]
N = x.shape[-1]
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
_tensorwise_dequantize_impl[(M, )](
x,
y,
a_scale,
b_scale,
stride_x=x.stride(-2),
stride_y=y.stride(-2),
N=N,
is_add=is_add,
is_float=is_float,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
@triton.jit
def _per_token_quant_fp8_to_int8(
x_ptr,
@@ -343,6 +405,18 @@ def channelwise_dequantize_transB(A, B, C):
out_scales = A * B.T
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
def tensorwise_dequantize(A, B, C, D):
_tensorwise_dequantize(A, B, C, D, is_add=False, is_float=False)
def tensorwise_dequantize_float(A, B, C, D):
_tensorwise_dequantize(A, B, C, D, is_add=False, is_float=True)
def tensorwise_dequantize_add(A, B, C, D):
_tensorwise_dequantize(A, B, C, D, is_add=True, is_float=False)
def tensorwise_dequantize_float_add(A, B, C, D):
_tensorwise_dequantize(A, B, C, D, is_add=True, is_float=True)
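
For reference, a plain-PyTorch equivalent of the four wrappers above (a sketch; it assumes D is preallocated with the target dtype and written in place, matching the Triton kernel's semantics):

def tensorwise_dequantize_ref(A, B, C, D, is_add=False):
    # A, B: scalar scale_inv tensors; C: INT32 GEMM output; D: output buffer
    result = C.to(torch.float32) * A * B
    if is_add:
        result = result + D.to(torch.float32)
    D.copy_(result.to(D.dtype))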
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
......