Commit a61caafd authored by yuguo
parents af196942 3eb6ea62
@@ -38,6 +38,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
# channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
+NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
...
+# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union, Optional
@@ -20,10 +21,16 @@ import transformer_engine_torch as tex
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
-from transformer_engine.pytorch.triton.per_token_group_quant import per_token_quant_int8, per_token_quant_int8_v2, per_token_quant_fp8_to_int8, per_token_quant_fp8_to_int8_v2, channelwise_dequantize, channelwise_dequantize_transA, channelwise_dequantize_transB, per_token_quant_fp8_to_int8_opt
+from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
+    per_token_quant_fp8_to_int8_opt,
+    channelwise_dequantize,
+    channelwise_dequantize_transA,
+    channelwise_dequantize_transB,
+    tensorwise_dequantize)
import time
+import os
+int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
# TN
m = 4096
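Editor's note, not part of the diff: the NVTE_INT8_SIM_FP8_TENSORWISE flag makes the "simulate FP8 with INT8" path use a single per-tensor scale instead of per-token (channelwise) scales. A minimal PyTorch sketch of what that tensorwise quantization amounts to is below; the function name and the 1e-12 amax clamp are illustrative assumptions, while the 127 clamp/round behaviour mirrors the kernel changes later in this commit.

# Hedged sketch only; names are illustrative, not TE API.
import torch

def tensorwise_int8_quant_reference(x: torch.Tensor):
    amax = x.abs().amax().float()
    scale = 127.0 / torch.clamp(amax, min=1e-12)    # one quantization scale for the whole tensor
    x_int8 = torch.clamp(torch.round(x.float() * scale), -127, 127).to(torch.int8)
    scale_inv = 1.0 / scale                         # kept alongside the int8 payload for dequant
    return x_int8, scale_inv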
@@ -227,6 +234,7 @@ transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
+output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
@@ -248,8 +256,12 @@ w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
-x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
+if int8_simulation_fp8_tensorwise:
+    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
+    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
+else:
+    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
@@ -278,40 +290,43 @@ y_int32 = tex.generic_gemm(
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
-output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(x_scales, w_scales, y_int32, output)
+else:
+    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
#     x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
#     x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
#     w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
#     y_int32 = tex.generic_gemm(
#         w_int8,
#         transa,
#         x_int8,
#         transb,
#         out,
#         out_quantizer,
#         TE_DType[out_dtype],
#         bias,
#         bias_dtype,
#         use_gelu,
#         aux_tensor,
#         use_grad,
#         workspace,
#         workspace.shape[0],
#         accumulate,
#         use_split_accumulator,
#     )[0]
#     output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# torch.cuda.synchronize()
# end = time.time()

# NN
@@ -321,6 +336,7 @@ transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
+dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
@@ -336,8 +352,12 @@ dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+if int8_simulation_fp8_tensorwise:
+    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
+    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
+else:
+    dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
@@ -372,7 +392,10 @@ dx_int32 = tex.generic_gemm(
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
-dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
+else:
+    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# print("dx_scales.shape: ", dx_scales.shape)
@@ -380,38 +403,38 @@ dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
#     dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
#     dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
#     w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
#     # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
#     # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
#     dx_int32 = tex.generic_gemm(
#         w_int8,
#         transa,
#         dy_int8,
#         transb,
#         out,
#         out_quantizer,
#         TE_DType[out_dtype],
#         bias,
#         bias_dtype,
#         use_gelu,
#         aux_tensor,
#         use_grad,
#         workspace,
#         workspace.shape[0],
#         accumulate,
#         use_split_accumulator,
#     )[0]
#     dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
#     # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# torch.cuda.synchronize()
# end = time.time()

# NT
@@ -423,6 +446,7 @@ transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
+dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -442,10 +466,14 @@ end = time.time()
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+if int8_simulation_fp8_tensorwise:
+    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
+    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
+else:
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
@@ -471,47 +499,50 @@ dw_int32 = tex.generic_gemm(
    use_split_accumulator,
)[0]
-dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
+else:
+    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
#     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
#     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
#     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
#     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
#     # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
#     # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
#     dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
#     x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
#     # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
#     # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
#     dw_int32 = tex.generic_gemm(
#         x_int8,
#         transa,
#         dy_int8,
#         transb,
#         out,
#         out_quantizer,
#         TE_DType[out_dtype],
#         bias,
#         bias_dtype,
#         use_gelu,
#         aux_tensor,
#         use_grad,
#         workspace,
#         workspace.shape[0],
#         accumulate,
#         use_split_accumulator,
#     )[0]
#     dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
#     # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
# torch.cuda.synchronize()
# end = time.time()
...
@@ -716,6 +716,12 @@ struct is_fp8<fp8e4m3> : std::true_type {};
template <>
struct is_fp8<fp8e5m2> : std::true_type {};
+
+template <typename T>
+struct is_int8 : std::false_type {};
+
+template <>
+struct is_int8<int8> : std::true_type {};
+
template <typename T>
struct is_fp4 : std::false_type {};
...
@@ -1479,8 +1479,8 @@ private:
};
// Define a static userArgs manager
-// static userArgsManager UAManager;
-// static d_userArgsManager d_UAManager;
+static userArgsManager UAManager;
+static d_userArgsManager d_UAManager;
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
                          std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  // Check compute_stream_offset valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  // int device_id;
-  // hipGetDevice(&device_id);
-  // hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  // hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  int device_id;
+  hipGetDevice(&device_id);
+  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
+  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
  // hipblaslt_ext::UserArguments* userArgs;
  // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  }
  // Get the default values from the groupedgemm object
-  // groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
+  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
  // Copy them to device memory
  // hipblaslt_ext::UserArguments* d_userArgs;
  // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
-  // NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
-  //                                userArgs,
-  //                                m.size() * sizeof(hipblaslt_ext::UserArguments),
-  //                                hipMemcpyHostToDevice, stream));
+  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
+                                 userArgs,
+                                 m.size() * sizeof(hipblaslt_ext::UserArguments),
+                                 hipMemcpyHostToDevice, stream));
  // Make sure to initialize every time the algo changes
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
  // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
  // NVTE_CHECK_CUDA(hipFree(userArgs));
...
@@ -291,8 +291,8 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
             "Tensor must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
-  NVTE_CHECK(is_fp8_dtype(output.data.dtype),
-             "Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
+  NVTE_CHECK(is_fp8_dtype(output.data.dtype) || is_int8_dtype(output.data.dtype),
+             "Tensor must be FP8 or INT8, but got dtype=", to_string(output.data.dtype));
  NVTE_CHECK(output.amax.numel() == 1,
             "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
             ")");
@@ -314,7 +314,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
  // Maximum FP8 value
  float max_fp8 = 0.f;
-  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
+  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
                                          max_fp8 = Quantized_Limits<DType>::max_norm;);
  // Update scale
...
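Editor's note, not part of the diff: the hunk above widens the scale-from-amax path so the INT8 max norm can be used in place of an FP8 one. A hedged Python restatement of that computation is below; the zero-amax fallback and function name are assumptions for illustration, while the per-format constants match the values used elsewhere in this commit (448 for E4M3, 57344 for E5M2, 127 for the INT8 simulation).

# Hedged sketch; not the C++ implementation.
def compute_scale_from_amax(amax: float, fmt: str) -> float:
    max_val = {"e4m3": 448.0, "e5m2": 57344.0, "int8": 127.0}[fmt]
    return max_val / amax if amax > 0.0 else 1.0  # fallback for amax == 0 is an assumption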
@@ -168,7 +168,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.dtype();
-  if (is_fp8_dtype(type)) {
+  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
...
@@ -1076,7 +1076,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
  const size_t N = product(input.data.shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
          output->data.dtype, OType,
          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
@@ -1105,7 +1105,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp
  const size_t N = product(input->data.shape);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input->data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
          output->data.dtype, OType,
          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
@@ -1275,6 +1275,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
+        const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
+        if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
+          NVTE_CHECK(false,
+                     "Columnwise (transposed) data is not supported when NVTE_INT8_SIM_FP8_TENSORWISE is set!");
+        }
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
...
@@ -231,12 +231,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif  // __HIP_PLATFORM_AMD__
static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
-  NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
-  NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
+  NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype), "Input must have FP8 or INT8 type.");
+  NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype), "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
  const size_t N = product(input.data.shape);
-  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          output->data.dtype, OType,
...
@@ -183,7 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__
    VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
    ComputeType max = 0;
    ComputeType s = 1;
-    if constexpr (is_fp8<OutputType>::value) {
+    if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
      if (scale != nullptr) s = *scale;
    }
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -196,18 +196,21 @@ __launch_bounds__(unary_kernel_threads) __global__
      for (int i = 0; i < nvec; ++i) {
        const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
        ComputeType temp = OP(val, p);
-        if constexpr (is_fp8<OutputType>::value) {
+        if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(temp), max);
          temp = temp * s;
        }
-        storer.separate()[i] = static_cast<OutputType>(temp);
+        if constexpr (is_int8<OutputType>::value) {
+          storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
+        } else {
+          storer.separate()[i] = static_cast<OutputType>(temp);
+        }
      }
      storer.store(tid, N);
    }
-    if constexpr (is_fp8<OutputType>::value) {
+    if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
      // Reduce amax over block
      if (amax != nullptr) {
        max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
@@ -236,7 +239,7 @@ __launch_bounds__(unary_kernel_threads) __global__
    VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
    ComputeType max = 0;
    ComputeType s = 1;
-    if constexpr (is_fp8<OutputType>::value) {
+    if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
      if (scale != nullptr) s = *scale;
    }
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -251,18 +254,21 @@ __launch_bounds__(unary_kernel_threads) __global__
        const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
        const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
        ComputeType temp = OP(val, p) * g;
-        if constexpr (is_fp8<OutputType>::value) {
+        if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
          __builtin_assume(max >= 0);
          max = fmaxf(fabsf(temp), max);
          temp = temp * s;
        }
-        storer.separate()[i] = static_cast<OutputType>(temp);
+        if constexpr (is_int8<OutputType>::value) {
+          storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
+        } else {
+          storer.separate()[i] = static_cast<OutputType>(temp);
+        }
      }
      storer.store(tid, N);
    }
-    if constexpr (is_fp8<OutputType>::value) {
+    if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
      // Reduce amax over block
      if (amax != nullptr) {
        max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
...
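Editor's note, not part of the diff: the kernel hunks above add an INT8 store path next to the FP8 one. In plain Python, the per-element behaviour is: scale the value, clamp it to the symmetric INT8 range, then round to the nearest integer, mirroring lroundf(fmaxf(-127.0f, fminf(127.0f, temp))). The function name below is illustrative.

# Hedged restatement of the INT8 store path (tie-breaking may differ slightly from lroundf).
def int8_store(temp: float, s: float) -> int:
    scaled = temp * s
    return int(round(max(-127.0, min(127.0, scaled))))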
@@ -30,10 +30,14 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
    channelwise_dequantize_transA_float,
    channelwise_dequantize_transB,
    channelwise_dequantize_transA_add,
-    channelwise_dequantize_transA_float_add)
+    channelwise_dequantize_transA_float_add,
+    tensorwise_dequantize,
+    tensorwise_dequantize_add,
+    tensorwise_dequantize_float,
+    tensorwise_dequantize_float_add)
from transformer_engine.pytorch.utils import get_device_compute_capability
-int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise
__all__ = [
    "general_gemm",
    "general_grouped_gemm",
@@ -191,8 +195,12 @@ def general_gemm(
    if layout == "TN":
        assert out_dtype is torch.bfloat16
        out_shape = B._data.shape[:-1] + (A._data.shape[0], )
-        x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            x_int8, x_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
        y_int32 = tex.generic_gemm(
            w_int8,
@@ -212,14 +220,22 @@ def general_gemm(
            False,
            use_split_accumulator,
        )[0]
-        y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+        if int8_simulation_fp8_tensorwise:
+            y = torch.empty_like(y_int32, device=y_int32.device, dtype=torch.bfloat16)
+            tensorwise_dequantize(x_scales, w_scales, y_int32, y)
+        else:
+            y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
        return y.view(out_shape), None, None, None
    elif layout == "NN":
        assert out_dtype is torch.bfloat16
        dx_shape = B._data.shape[:-1] + (A._data.shape[-1], )
-        dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
        dx_int32 = tex.generic_gemm(
            w_int8,
@@ -239,13 +255,21 @@ def general_gemm(
            False,
            use_split_accumulator,
        )[0]
-        dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+        if int8_simulation_fp8_tensorwise:
+            dx = torch.empty_like(dx_int32, device=dx_int32.device, dtype=torch.bfloat16)
+            tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
+        else:
+            dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
        return dx.view(dx_shape), None, None, None
    elif layout == "NT":
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
-        dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            x_int8, x_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
        dw_int32 = tex.generic_gemm(
            x_int8,
@@ -267,14 +291,29 @@ def general_gemm(
        )[0]
        if out_dtype is torch.bfloat16:
            if accumulate:
-                channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
+                if int8_simulation_fp8_tensorwise:
+                    tensorwise_dequantize_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
            else:
-                out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+                if int8_simulation_fp8_tensorwise:
+                    out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.bfloat16)
+                    tensorwise_dequantize(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
        else:
            if accumulate:
-                channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
+                if int8_simulation_fp8_tensorwise:
+                    tensorwise_dequantize_float_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
            else:
-                out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
+                if int8_simulation_fp8_tensorwise:
+                    out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.float32)
+                    tensorwise_dequantize_float(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
        return out, None, None, None
    else:
...
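Editor's note, not part of the diff: a hedged summary of how the dequantization kernel is picked in general_gemm above. The channelwise path needs a layout-specific variant because the per-token scales attach to different GEMM operands, while the tensorwise simulation reuses one kernel (plus its _add/_float variants) for every layout. The layout-to-computation comments are inferred from the test file earlier in this commit and are illustrative rather than authoritative.

# Hedged summary only; these dicts are not TE API.
CHANNELWISE_DEQUANT = {
    "TN": "channelwise_dequantize_transB",   # fprop:  y  ~ x @ w.T
    "NN": "channelwise_dequantize",          # dgrad:  dx ~ dy @ w
    "NT": "channelwise_dequantize_transA",   # wgrad:  dw ~ dy.T @ x (bf16/fp32 and add variants)
}
TENSORWISE_DEQUANT = "tensorwise_dequantize"  # plus _add, _float, _float_add variants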
@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = ["fp8_autocast", "fp8_model_init"]
...
@@ -21,7 +21,7 @@ from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
aten = torch.ops.aten
-int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8
class Float8BlockQuantizer(Quantizer):
    """Builder class for tensors quantized with current scaling using
...
@@ -16,6 +16,7 @@ from ..utils import canonicalize_process_group, devices_match
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8_tensorwise
aten = torch.ops.aten
@@ -232,7 +233,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.scale = torch.ones(1, dtype=torch.float32, device=device)
        self.amax = torch.empty(1, dtype=torch.float32, device=device)
-        self.dtype = fp8_dtype
+        self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
        self.with_amax_reduction = with_amax_reduction
        self.amax_reduction_group = amax_reduction_group
        self.force_pow_2_scales = force_pow_2_scales
...
@@ -284,6 +284,8 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
        max_fp8 = 448.0
    elif fp8_dtype == tex.DType.kFloat8E5M2:
        max_fp8 = 57344.0
+    elif fp8_dtype == tex.DType.kInt8:
+        max_fp8 = 127.0
    else:
        raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
    multi_tensor_applier(
...
@@ -122,6 +122,68 @@ def per_token_quant_int8_v2(x):
+@triton.jit
+def _tensorwise_dequantize_impl(
+    x_ptr,
+    y_ptr,
+    scaleA_ptr,
+    scaleB_ptr,
+    stride_x,
+    stride_y,
+    N,
+    is_add: tl.constexpr,
+    is_float: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    a_scale = tl.load(scaleA_ptr)
+    b_scale = tl.load(scaleB_ptr)
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    result = x * a_scale * b_scale
+    if is_add:
+        y = tl.load(y_ptr + row_id * stride_y + cols, mask=mask, other=0.0).to(tl.float32)
+        result += y
+    if is_float:
+        tl.store(y_ptr + row_id * stride_y + cols, result, mask=mask)
+    else:
+        tl.store(y_ptr + row_id * stride_y + cols, result.to(tl.bfloat16), mask=mask)
+
+
+def _tensorwise_dequantize(a_scale, b_scale, x, y, is_add=False, is_float=False):
+    assert x.is_contiguous()
+    x = x.view(-1, x.shape[-1])
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    assert x.is_contiguous()
+    _tensorwise_dequantize_impl[(M, )](
+        x,
+        y,
+        a_scale,
+        b_scale,
+        stride_x=x.stride(-2),
+        stride_y=y.stride(-2),
+        N=N,
+        is_add=is_add,
+        is_float=is_float,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
@triton.jit
def _per_token_quant_fp8_to_int8(
    x_ptr,
@@ -343,6 +405,18 @@ def channelwise_dequantize_transB(A, B, C):
    out_scales = A * B.T
    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
+
+def tensorwise_dequantize(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=False, is_float=False)
+
+def tensorwise_dequantize_float(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=False, is_float=True)
+
+def tensorwise_dequantize_add(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=True, is_float=False)
+
+def tensorwise_dequantize_float_add(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=True, is_float=True)
def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
...
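Editor's note, not part of the diff: the four wrappers above write their result into the pre-allocated output D in place, choosing bfloat16 or float32 storage and optional accumulation via the (is_add, is_float) flags. A hedged pure-PyTorch equivalent for sanity-checking the Triton path is sketched below; the function name and usage lines are illustrative assumptions.

# Hedged reference; expect small differences from the bf16 store in the non-_float variants.
import torch

def tensorwise_dequantize_reference(a_scale, b_scale, c_int32, is_add=False, out=None):
    result = c_int32.to(torch.float32) * a_scale * b_scale
    if is_add and out is not None:
        result = result + out.to(torch.float32)
    return result

# Illustrative usage:
#   y = torch.empty(c_int32.shape, device=c_int32.device, dtype=torch.bfloat16)
#   tensorwise_dequantize(a_scale, b_scale, c_int32, y)
#   ref = tensorwise_dequantize_reference(a_scale, b_scale, c_int32)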