OpenDAS / TransformerEngine · Commits

Commit 3eb6ea62
Authored Aug 08, 2025 by yuguo
[DCU] add NVTE_INT8_SIM_FP8_TENSORWISE
Parent: 68d6c506

Showing 15 changed files with 322 additions and 156 deletions.
qa/L0_pytorch_unittest/test.sh                                 +1   -0
tests/pytorch/test_int8_channelwise_gemm_exact.py              +137 -106
transformer_engine/common/common.h                             +6   -0
transformer_engine/common/gemm/rocm_gemm.cu                    +15  -15
transformer_engine/common/recipe/current_scaling.cu            +3   -3
transformer_engine/common/transformer_engine.cpp               +1   -1
transformer_engine/common/util/cast_kernels.cuh                +7   -2
transformer_engine/common/util/dequantize_kernels.cuh          +3   -3
transformer_engine/common/util/vectorized_pointwise.h          +16  -10
transformer_engine/pytorch/cpp_extensions/gemm.py              +53  -14
transformer_engine/pytorch/fp8.py                              +1   -0
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py   +1   -1
transformer_engine/pytorch/tensor/float8_tensor.py             +2   -1
transformer_engine/pytorch/tensor/utils.py                     +2   -0
transformer_engine/pytorch/triton/per_token_group_quant.py     +74  -0
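The change is gated by a new environment variable, NVTE_INT8_SIM_FP8_TENSORWISE: when it is set to 1, FP8 current-scaling tensors keep their single per-tensor scale_inv, their byte payload is reinterpreted as int8, the GEMM accumulates in int32, and the result is rescaled by the product of the two scalar scales. A minimal reference sketch of that dequantization step in plain PyTorch (the helper name is illustrative, not part of the repository):

import torch

def tensorwise_dequantize_ref(a_scale, b_scale, c_int32, out_dtype=torch.bfloat16):
    # One scalar scale per operand, applied to the int32 GEMM accumulator.
    return (c_int32.to(torch.float32) * a_scale * b_scale).to(out_dtype)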
qa/L0_pytorch_unittest/test.sh

@@ -36,6 +36,7 @@
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 # channelwise int8 test
 NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
+NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s test_float8_current_scaling_exact.py
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
 python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
tests/pytorch/test_int8_channelwise_gemm_exact.py

+# NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 test_int8_channelwise_gemm_exact.py
 from collections.abc import Iterable
 import io
 from typing import Any, Dict, List, Tuple, Union, Optional
@@ -20,10 +21,16 @@ import transformer_engine_torch as tex
 from references.ref_per_tensor_cs import ref_per_tensor_cs_cast
-from transformer_engine.pytorch.triton.per_token_group_quant import per_token_quant_int8, per_token_quant_int8_v2, per_token_quant_fp8_to_int8, per_token_quant_fp8_to_int8_v2, channelwise_dequantize, channelwise_dequantize_transA, channelwise_dequantize_transB, per_token_quant_fp8_to_int8_opt
+from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
+    per_token_quant_fp8_to_int8_opt, channelwise_dequantize, channelwise_dequantize_transA,
+    channelwise_dequantize_transB, tensorwise_dequantize)
 import time
 import os
+
+int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))

 # TN
 m = 4096
@@ -227,6 +234,7 @@ transb = False
 x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
 w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
+output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
 bf16_out = torch.matmul(x_bf16, w_bf16.t())
@@ -248,8 +256,12 @@ w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
 # print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
 # print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
-x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
+if int8_simulation_fp8_tensorwise:
+    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
+    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
+else:
+    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
 # print("x_int8: ", x_int8)
 # print("w_int8: ", w_int8)
@@ -278,40 +290,43 @@ y_int32 = tex.generic_gemm(
 # y_int32 = torch._int_mm(x_int8, w_int8.t())
 # print("y_int32: ", y_int32)
-output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(x_scales, w_scales, y_int32, output)
+else:
+    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
 # print("out_scales.shape: ", out_scales.shape)
 # print("out_scales: ", out_scales)
 print("bf16_out: ", bf16_out)
 print("output: ", output)
-torch.cuda.synchronize()
-start = time.time()
-for i in range(20):
-    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
-    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
-    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
-    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
-    y_int32 = tex.generic_gemm(
-        w_int8,
-        transa,
-        x_int8,
-        transb,
-        out,
-        out_quantizer,
-        TE_DType[out_dtype],
-        bias,
-        bias_dtype,
-        use_gelu,
-        aux_tensor,
-        use_grad,
-        workspace,
-        workspace.shape[0],
-        accumulate,
-        use_split_accumulator,
-    )[0]
-    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
-torch.cuda.synchronize()
-end = time.time()
+# torch.cuda.synchronize()
+# start = time.time()
+# for i in range(20):
+#     x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+#     x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+#     w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
+#     y_int32 = tex.generic_gemm(
+#         w_int8,
+#         transa,
+#         x_int8,
+#         transb,
+#         out,
+#         out_quantizer,
+#         TE_DType[out_dtype],
+#         bias,
+#         bias_dtype,
+#         use_gelu,
+#         aux_tensor,
+#         use_grad,
+#         workspace,
+#         workspace.shape[0],
+#         accumulate,
+#         use_split_accumulator,
+#     )[0]
+#     output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+# torch.cuda.synchronize()
+# end = time.time()

 # NN
@@ -321,6 +336,7 @@ transb = False
 dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
 w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
+dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
 bf16_dx = torch.matmul(dy_bf16, w_bf16)
@@ -336,8 +352,12 @@ dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
 w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
 # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+if int8_simulation_fp8_tensorwise:
+    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
+    w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
+else:
+    dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
 # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
 # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
@@ -372,7 +392,10 @@ dx_int32 = tex.generic_gemm(
 # dx_int32 = torch._int_mm(dy_int8, w_int8)
 # print("dx_int32: ", dx_int32)
-dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
+else:
+    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
 # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
 # print("dx_scales.shape: ", dx_scales.shape)
@@ -380,38 +403,38 @@ dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
 print("bf16_dx: ", bf16_dx)
 print("dx: ", dx)
-torch.cuda.synchronize()
-start = time.time()
-for i in range(20):
-    dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-    dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
-    # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
-    # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
-    dx_int32 = tex.generic_gemm(
-        w_int8,
-        transa,
-        dy_int8,
-        transb,
-        out,
-        out_quantizer,
-        TE_DType[out_dtype],
-        bias,
-        bias_dtype,
-        use_gelu,
-        aux_tensor,
-        use_grad,
-        workspace,
-        workspace.shape[0],
-        accumulate,
-        use_split_accumulator,
-    )[0]
-    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
-    # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
-torch.cuda.synchronize()
-end = time.time()
+# torch.cuda.synchronize()
+# start = time.time()
+# for i in range(20):
+#     dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+#     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
+#     dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+#     w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+#     # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+#     # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+#     dx_int32 = tex.generic_gemm(
+#         w_int8,
+#         transa,
+#         dy_int8,
+#         transb,
+#         out,
+#         out_quantizer,
+#         TE_DType[out_dtype],
+#         bias,
+#         bias_dtype,
+#         use_gelu,
+#         aux_tensor,
+#         use_grad,
+#         workspace,
+#         workspace.shape[0],
+#         accumulate,
+#         use_split_accumulator,
+#     )[0]
+#     dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+#     # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
+# torch.cuda.synchronize()
+# end = time.time()

 # NT
@@ -423,6 +446,7 @@ transb = True
 dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
 x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
+dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
 seed = 0
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
@@ -442,10 +466,14 @@ end = time.time()
 dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
 x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
+if int8_simulation_fp8_tensorwise:
+    dy_int8, dy_scales = dy_fp8._data.view(dtype=torch.int8), dy_fp8._scale_inv
+    x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
+else:
+    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
+    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
 # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
 # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
@@ -471,47 +499,50 @@ dw_int32 = tex.generic_gemm(
     use_split_accumulator,
 )[0]
-dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+if int8_simulation_fp8_tensorwise:
+    tensorwise_dequantize(dy_scales, x_scales, dw_int32, dw)
+else:
+    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
 # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
 print("bf16_dw: ", bf16_dw)
 print("dw: ", dw)
-torch.cuda.synchronize()
-start = time.time()
-for i in range(20):
-    # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-    # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-    # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-    # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
-    # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-    dw_int32 = tex.generic_gemm(
-        x_int8,
-        transa,
-        dy_int8,
-        transb,
-        out,
-        out_quantizer,
-        TE_DType[out_dtype],
-        bias,
-        bias_dtype,
-        use_gelu,
-        aux_tensor,
-        use_grad,
-        workspace,
-        workspace.shape[0],
-        accumulate,
-        use_split_accumulator,
-    )[0]
-    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
-    # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
-torch.cuda.synchronize()
-end = time.time()
+# torch.cuda.synchronize()
+# start = time.time()
+# for i in range(20):
+#     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
+#     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
+#     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+#     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+#     # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+#     # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
+#     dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+#     x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
+#     # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
+#     # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
+#     dw_int32 = tex.generic_gemm(
+#         x_int8,
+#         transa,
+#         dy_int8,
+#         transb,
+#         out,
+#         out_quantizer,
+#         TE_DType[out_dtype],
+#         bias,
+#         bias_dtype,
+#         use_gelu,
+#         aux_tensor,
+#         use_grad,
+#         workspace,
+#         workspace.shape[0],
+#         accumulate,
+#         use_split_accumulator,
+#     )[0]
+#     dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+#     # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
+# torch.cuda.synchronize()
+# end = time.time()
transformer_engine/common/common.h

@@ -716,6 +716,12 @@ struct is_fp8<fp8e4m3> : std::true_type {};
 template <>
 struct is_fp8<fp8e5m2> : std::true_type {};

+template <typename T>
+struct is_int8 : std::false_type {};
+
+template <>
+struct is_int8<int8> : std::true_type {};
+
 template <typename T>
 struct is_fp4 : std::false_type {};
transformer_engine/common/gemm/rocm_gemm.cu

@@ -1479,8 +1479,8 @@ private:
 };

 // Define a static userArgs manager
-// static userArgsManager UAManager;
-// static d_userArgsManager d_UAManager;
+static userArgsManager UAManager;
+static d_userArgsManager d_UAManager;

 void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                           std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
                           std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
                           hipblasOperation_t transa, hipblasOperation_t transb,
@@ -1489,10 +1489,10 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  // int device_id;
-  // hipGetDevice(&device_id);
-  // hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  // hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  int device_id;
+  hipGetDevice(&device_id);
+  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
+  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
   // hipblaslt_ext::UserArguments* userArgs;
   // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1566,20 +1566,20 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   }

   // Get the default values from the grouepdgemm object
-  // groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
+  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);

   // Copy them to device memory
   // hipblaslt_ext::UserArguments* d_userArgs;
   // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
-  // NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
-  //                                userArgs,
-  //                                m.size() * sizeof(hipblaslt_ext::UserArguments),
-  //                                hipMemcpyHostToDevice, stream));
+  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs, userArgs,
+                                 m.size() * sizeof(hipblaslt_ext::UserArguments),
+                                 hipMemcpyHostToDevice, stream));

   // Make sure to initialize everytime the algo changes
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
-  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
-  NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
   // NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
   // NVTE_CHECK_CUDA(hipFree(userArgs));
transformer_engine/common/recipe/current_scaling.cu

@@ -291,8 +291,8 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
              "Tensor must be FP8 tensor with per-tensor scaling, "
              "but got scaling_mode=", to_string(output.scaling_mode));
-  NVTE_CHECK(is_fp8_dtype(output.data.dtype),
-             "Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
+  NVTE_CHECK(is_fp8_dtype(output.data.dtype) || is_int8_dtype(output.data.dtype),
+             "Tensor must be FP8 or INT8, but got dtype=", to_string(output.data.dtype));
   NVTE_CHECK(output.amax.numel() == 1,
              "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape, ")");
@@ -314,7 +314,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
   // Maximum FP8 value
   float max_fp8 = 0.f;
-  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
+  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
                                          max_fp8 = Quantized_Limits<DType>::max_norm;);

   // Update scale
transformer_engine/common/transformer_engine.cpp

@@ -166,7 +166,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
 void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
   const DType type = t.dtype();
-  if (is_fp8_dtype(type)) {
+  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
     // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
     if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
       NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
transformer_engine/common/util/cast_kernels.cuh

@@ -1076,7 +1076,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop,
   const size_t N = product(input.data.shape);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
           output->data.dtype, OType,
           if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
             constexpr int nvec = 32 / sizeof(IType);
@@ -1105,7 +1105,7 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *inp
   const size_t N = product(input->data.shape);
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input->data.dtype, IType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
          output->data.dtype, OType,
          if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) {
            constexpr int nvec = 32 / sizeof(IType);
@@ -1275,6 +1275,11 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
   switch (output_tensor->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
       if (output_tensor->has_columnwise_data()) {
+        const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
+        if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1') {
+          NVTE_CHECK(false, "NVTE_INT8_SIM_FP8_TENSORWISE need not be transposed!");
+        }
         NVTE_CHECK(output_tensor->has_data(),
                    "Quantizing in only the columnwise direction not supported yet!");
         if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
transformer_engine/common/util/dequantize_kernels.cuh

@@ -231,12 +231,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 #endif  // __HIP_PLATFORM_AMD__

 static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
-  NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
-  NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
+  NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype), "Input must have FP8 or INT8 type.");
+  NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype), "Output must be in higher precision.");
   NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

   const size_t N = product(input.data.shape);
-  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
       input.data.dtype, IType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           output->data.dtype, OType,
transformer_engine/common/util/vectorized_pointwise.h

@@ -183,7 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__
   VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
   ComputeType max = 0;
   ComputeType s = 1;
-  if constexpr (is_fp8<OutputType>::value) {
+  if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
     if (scale != nullptr) s = *scale;
   }
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -196,18 +196,21 @@ __launch_bounds__(unary_kernel_threads) __global__
     for (int i = 0; i < nvec; ++i) {
      const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
      ComputeType temp = OP(val, p);
-      if constexpr (is_fp8<OutputType>::value) {
+      if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
        __builtin_assume(max >= 0);
        max = fmaxf(fabsf(temp), max);
        temp = temp * s;
      }
-      storer.separate()[i] = static_cast<OutputType>(temp);
+      if constexpr (is_int8<OutputType>::value) {
+        storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
+      } else {
+        storer.separate()[i] = static_cast<OutputType>(temp);
+      }
     }
     storer.store(tid, N);
   }
-  if constexpr (is_fp8<OutputType>::value) {
+  if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
     // Reduce amax over block
     if (amax != nullptr) {
       max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
@@ -236,7 +239,7 @@ __launch_bounds__(unary_kernel_threads) __global__
   VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
   ComputeType max = 0;
   ComputeType s = 1;
-  if constexpr (is_fp8<OutputType>::value) {
+  if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
     if (scale != nullptr) s = *scale;
   }
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -251,18 +254,21 @@ __launch_bounds__(unary_kernel_threads) __global__
      const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
      const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
      ComputeType temp = OP(val, p) * g;
-      if constexpr (is_fp8<OutputType>::value) {
+      if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
        __builtin_assume(max >= 0);
        max = fmaxf(fabsf(temp), max);
        temp = temp * s;
      }
-      storer.separate()[i] = static_cast<OutputType>(temp);
+      if constexpr (is_int8<OutputType>::value) {
+        storer.separate()[i] = static_cast<OutputType>(lroundf(fmaxf(-127.0f, fminf(127.0f, temp))));
+      } else {
+        storer.separate()[i] = static_cast<OutputType>(temp);
+      }
     }
     storer.store(tid, N);
   }
-  if constexpr (is_fp8<OutputType>::value) {
+  if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
     // Reduce amax over block
     if (amax != nullptr) {
       max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
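With is_int8 wired into these kernels, an INT8 output shares the FP8 amax/scale path, but the store rounds and clamps to the symmetric range [-127, 127] instead of a plain cast. A rough PyTorch equivalent of that store step, for illustration only (the authoritative version is the CUDA above):

import torch

def int8_store_ref(temp: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors the kernel: apply the scale, then round and clamp to [-127, 127].
    scaled = temp * scale
    return torch.round(scaled.clamp(-127.0, 127.0)).to(torch.int8)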
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -30,10 +30,14 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
-    channelwise_dequantize_transA_float, channelwise_dequantize_transB,
-    channelwise_dequantize_transA_add, channelwise_dequantize_transA_float_add)
+    channelwise_dequantize_transA_float, channelwise_dequantize_transB,
+    channelwise_dequantize_transA_add, channelwise_dequantize_transA_float_add,
+    tensorwise_dequantize, tensorwise_dequantize_add,
+    tensorwise_dequantize_float, tensorwise_dequantize_float_add)
 from transformer_engine.pytorch.utils import get_device_compute_capability
-int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise

 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
@@ -191,8 +195,12 @@ def general_gemm(
     if layout == "TN":
         assert out_dtype is torch.bfloat16
         out_shape = B._data.shape[:-1] + (A._data.shape[0],)
-        x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            x_int8, x_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
         y_int32 = tex.generic_gemm(
             w_int8,
@@ -212,14 +220,22 @@ def general_gemm(
             False,
             use_split_accumulator,
         )[0]
-        y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
+        if int8_simulation_fp8_tensorwise:
+            y = torch.empty_like(y_int32, device=y_int32.device, dtype=torch.bfloat16)
+            tensorwise_dequantize(x_scales, w_scales, y_int32, y)
+        else:
+            y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
         return y.view(out_shape), None, None, None
     elif layout == "NN":
         assert out_dtype is torch.bfloat16
         dx_shape = B._data.shape[:-1] + (A._data.shape[-1],)
-        dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
         dx_int32 = tex.generic_gemm(
             w_int8,
@@ -239,13 +255,21 @@ def general_gemm(
             False,
             use_split_accumulator,
         )[0]
-        dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
+        if int8_simulation_fp8_tensorwise:
+            dx = torch.empty_like(dx_int32, device=dx_int32.device, dtype=torch.bfloat16)
+            tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
+        else:
+            dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
         return dx.view(dx_shape), None, None, None
     elif layout == "NT":
         assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
-        dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
-        x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
+        if int8_simulation_fp8_tensorwise:
+            dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
+            x_int8, x_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
+        else:
+            dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
+            x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
         dw_int32 = tex.generic_gemm(
             x_int8,
@@ -267,14 +291,29 @@ def general_gemm(
         )[0]
         if out_dtype is torch.bfloat16:
             if accumulate:
-                channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
+                if int8_simulation_fp8_tensorwise:
+                    tensorwise_dequantize_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
             else:
-                out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
+                if int8_simulation_fp8_tensorwise:
+                    out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.bfloat16)
+                    tensorwise_dequantize(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
         else:
             if accumulate:
-                channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
+                if int8_simulation_fp8_tensorwise:
+                    tensorwise_dequantize_float_add(dy_scales, x_scales, dw_int32, out)
+                else:
+                    channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
             else:
-                out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
+                if int8_simulation_fp8_tensorwise:
+                    out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.float32)
+                    tensorwise_dequantize_float(dy_scales, x_scales, dw_int32, out)
+                else:
+                    out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
         return out, None, None, None
     else:
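Across the TN/NN/NT branches above the pattern is the same: in tensorwise mode the FP8 payloads of A and B are reused directly as int8 together with their scalar _scale_inv, and the channelwise dequantize of the int32 result is swapped for the tensorwise variant, which writes into a preallocated output. A condensed sketch of that dispatch, with a hypothetical helper name (the real code inlines the logic per branch):

import torch
from transformer_engine.pytorch.triton.per_token_group_quant import tensorwise_dequantize

def dequantize_int32_result(scale_a, scale_b, result_int32, tensorwise, channelwise_fn):
    # Hypothetical condensation of the per-branch logic in general_gemm.
    if tensorwise:
        out = torch.empty_like(result_int32, dtype=torch.bfloat16)
        tensorwise_dequantize(scale_a, scale_b, result_int32, out)  # scalar scales, writes into out
        return out
    return channelwise_fn(scale_a, scale_b, result_int32)  # per-row/column scale vectors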
transformer_engine/pytorch/fp8.py

@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
 from .jit import jit_fuser
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
 blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))

 __all__ = ["fp8_autocast", "fp8_model_init"]
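Because both flags are evaluated once at module import, they must be present in the environment before transformer_engine.pytorch is first imported. A sketch of how a script would enable the tensorwise simulation, mirroring the test.sh invocation above:

import os

# Set before the first transformer_engine.pytorch import; fp8.py reads these at module load time.
os.environ["NVTE_INT8_SIM_FP8"] = "1"
os.environ["NVTE_INT8_SIM_FP8_TENSORWISE"] = "1"

import transformer_engine.pytorch  # noqa: E402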
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -20,7 +20,7 @@ from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 aten = torch.ops.aten
-int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8


 class Float8BlockQuantizer(Quantizer):
     """Builder class for tensors quantized with current scaling using
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -16,6 +16,7 @@ from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..constants import dist_group_type
+from transformer_engine.pytorch.fp8 import int8_simulation_fp8_tensorwise

 aten = torch.ops.aten
@@ -217,7 +218,7 @@ class Float8CurrentScalingQuantizer(Quantizer):
         super().__init__(rowwise=rowwise, columnwise=columnwise)
         self.scale = torch.ones(1, dtype=torch.float32, device=device)
         self.amax = torch.empty(1, dtype=torch.float32, device=device)
-        self.dtype = fp8_dtype
+        self.dtype = tex.DType.kInt8 if int8_simulation_fp8_tensorwise else fp8_dtype
         self.with_amax_reduction = with_amax_reduction
         self.amax_reduction_group = amax_reduction_group
         self.force_pow_2_scales = force_pow_2_scales
transformer_engine/pytorch/tensor/utils.py

@@ -284,6 +284,8 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
         max_fp8 = 448.0
     elif fp8_dtype == tex.DType.kFloat8E5M2:
         max_fp8 = 57344.0
+    elif fp8_dtype == tex.DType.kInt8:
+        max_fp8 = 127.0
     else:
         raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")

     multi_tensor_applier(
transformer_engine/pytorch/triton/per_token_group_quant.py

@@ -122,6 +122,68 @@ def per_token_quant_int8_v2(x):
+@triton.jit
+def _tensorwise_dequantize_impl(
+    x_ptr,
+    y_ptr,
+    scaleA_ptr,
+    scaleB_ptr,
+    stride_x,
+    stride_y,
+    N,
+    is_add: tl.constexpr,
+    is_float: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    a_scale = tl.load(scaleA_ptr)
+    b_scale = tl.load(scaleB_ptr)
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    result = x * a_scale * b_scale
+    if is_add:
+        y = tl.load(y_ptr + row_id * stride_y + cols, mask=mask, other=0.0).to(tl.float32)
+        result += y
+    if is_float:
+        tl.store(y_ptr + row_id * stride_y + cols, result, mask=mask)
+    else:
+        tl.store(y_ptr + row_id * stride_y + cols, result.to(tl.bfloat16), mask=mask)
+
+
+def _tensorwise_dequantize(a_scale, b_scale, x, y, is_add=False, is_float=False):
+    assert x.is_contiguous()
+    x = x.view(-1, x.shape[-1])
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    assert x.is_contiguous()
+    _tensorwise_dequantize_impl[(M,)](
+        x,
+        y,
+        a_scale,
+        b_scale,
+        stride_x=x.stride(-2),
+        stride_y=y.stride(-2),
+        N=N,
+        is_add=is_add,
+        is_float=is_float,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
 @triton.jit
 def _per_token_quant_fp8_to_int8(
     x_ptr,
@@ -343,6 +405,18 @@ def channelwise_dequantize_transB(A, B, C):
     out_scales = A * B.T
     return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)

+def tensorwise_dequantize(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=False, is_float=False)
+
+def tensorwise_dequantize_float(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=False, is_float=True)
+
+def tensorwise_dequantize_add(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=True, is_float=False)
+
+def tensorwise_dequantize_float_add(A, B, C, D):
+    _tensorwise_dequantize(A, B, C, D, is_add=True, is_float=True)
+
 def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
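The four wrappers only select the kernel's constexpr flags (accumulate into the output or not, float32 or bfloat16 output). A small usage sketch under the assumption of a CUDA/ROCm device, with illustrative shapes and scale values:

import torch
from transformer_engine.pytorch.triton.per_token_group_quant import tensorwise_dequantize

y_int32 = torch.randint(-1000, 1000, (4096, 4096), dtype=torch.int32, device="cuda")
a_scale = torch.tensor([0.01], dtype=torch.float32, device="cuda")  # scale_inv of operand A
b_scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")  # scale_inv of operand B
out = torch.empty_like(y_int32, dtype=torch.bfloat16)
tensorwise_dequantize(a_scale, b_scale, y_int32, out)  # out = a_scale * b_scale * y_int32, in bfloat16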