OpenDAS / TransformerEngine

Commit 1b303e91, authored Jun 04, 2025 by yuguo
Parents: 52ba87a1, 735227cd

    Merge branch 'develop_v2.3' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine

Showing 11 changed files with 2335 additions and 9 deletions (+2335 −9)
qa/L0_pytorch_unittest/test.sh                                        +2    −0
tests/pytorch/test_float8_current_scaling_exact.py                    +2    −2
tests/pytorch/test_int8_blockwise_gemm_exact.py                     +708    −0
tests/pytorch/test_int8_blockwise_layers.py                         +175    −0
transformer_engine/pytorch/constants.py                               +1    −0
transformer_engine/pytorch/cpp_extensions/gemm.py                    +64    −1
transformer_engine/pytorch/fp8.py                                     +9    −4
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py          +3    −2
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +563    −0
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +568    −0
transformer_engine/pytorch/triton/per_token_group_quant.py          +240    −0
qa/L0_pytorch_unittest/test.sh

```diff
@@ -36,6 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_int8_blockwise_layers.xml $TE_PATH/tests/pytorch/test_int8_blockwise_layers.py || test_fail "test_int8_blockwise_layers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
```
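The two added lines run the new INT8 blockwise GEMM exactness script directly and run the new layer tests with INT8-simulated FP8 enabled. A standalone invocation might look like the following sketch (paths are illustrative; the variable has to be on the command line because the modules read it at import time):

```bash
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s tests/pytorch/test_int8_blockwise_layers.py
python3 tests/pytorch/test_int8_blockwise_gemm_exact.py
```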
tests/pytorch/test_float8_current_scaling_exact.py

```diff
@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
```
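Both hunks stop comparing against the `GetRecipes.none` factory itself and instead ask the factory for a recipe. The `GetRecipes` helper added in `test_int8_blockwise_layers.py` below makes the intent concrete; a self-contained sketch of the two forms:

```python
from transformer_engine.common.recipe import Float8BlockScaling

class GetRecipes:
    @staticmethod
    def none():
        return None

    @staticmethod
    def fp8_blockwise():
        return Float8BlockScaling()  # default block-scaling recipe

# old: using_fp8_recipe = recipe1 != GetRecipes.none   (compares the factory object)
# new: using_fp8_recipe = recipe1() is not None        (inspects what it returns)
assert GetRecipes.none() is None
assert GetRecipes.fp8_blockwise() is not None
```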
tests/pytorch/test_int8_blockwise_gemm_exact.py (new file, 0 → 100644)

```python
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm


def fp8_blockwise_gemm_supported() -> bool:
    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
    return supported
```
```python
def cublas_gemm_fp8_blockwise_case_fw(
    x_dtype,
    w_dtype,
    out_dtype,
    M,
    K,
    N,
    noise_type,
    x_magnitude,
    w_magnitude,
    accumulate,
    use_split_accumulator,
    is_x_1d_scaled,
    is_w_1d_scaled,
    *,
    x_columnwise: bool = False,
    w_columnwise: bool = False,
    use_bias: bool = False,
    use_gelu: bool = False,
    use_grad: bool = False,
    atol: float = 5e-1,
    rtol: float = 5e-1,
):
    if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
        pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
    if not (is_x_1d_scaled or is_w_1d_scaled):
        pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x_shape = (K, M) if x_columnwise else (M, K)
    w_shape = (K, N) if w_columnwise else (N, K)

    # generate random input and weight
    if noise_type == "uniform":
        x = torch.rand(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude * 2 - x_magnitude
        w = torch.rand(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude * 2 - w_magnitude
    elif noise_type == "normal":
        x = torch.randn(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude
        w = torch.randn(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude
    else:
        assert False

    bf16_out = torch.matmul(x, w.t())
    # print(f"x.shape: {x.shape}, w.shape: {w.shape}")
    # print("bf16 gemm output: ", bf16_out)
    # print("bf16 gemm output shape: ", bf16_out.shape)

    # Setup out tensor if accumulate is True
    if accumulate:
        out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude
    else:
        out = None

    assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

    # Set quantize_op and quantization parameters
    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
    x_block_scaling_dim = 1 if is_x_1d_scaled else 2
    w_block_scaling_dim = 1 if is_w_1d_scaled else 2
    x_te_dtype = TE_DType[x_dtype]
    w_te_dtype = TE_DType[w_dtype]
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )

    # Quantize x and w
    qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)
    qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)

    if not use_bias:
        bias = None
    else:
        bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)

    # Reference GEMM
    ref_gemm = CuBLASRefBlockwiseGemm()
    scale_decoder = CuBLASScaleMunger()
    qx_data = (
        qx._columnwise_data.view(dtype=x_dtype)
        if x_columnwise
        else qx._rowwise_data.view(dtype=x_dtype)
    )
    qw_data = (
        qw._columnwise_data.view(dtype=w_dtype)
        if w_columnwise
        else qw._rowwise_data.view(dtype=w_dtype)
    )
    ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
    ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
    # print(f"qx_data.shape: {qx_data.shape}, qw_data.shape: {qw_data.shape}")
    # print(f"ref_scales_x.shape: {ref_scales_x.shape}, ref_scales_w.shape: {ref_scales_w.shape}")
    # print(f"ref_scales_x_t.shape: {ref_scales_x.t().shape}")
    # print(f"ref_scales_x_columnwise.shape: {qx._columnwise_scale_inv.shape}")
    y_ref = ref_gemm.qgemm(
        qx=qx_data,
        qw=qw_data,
        out_dtype=out_dtype,
        demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape
        ),
        demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape
        ),
        quant_tile_shape_x=x_quant_tile_shape,
        quant_tile_shape_w=w_quant_tile_shape,
        bias=bias,
        out=out.clone() if accumulate else None,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )
    # print("fp8 gemm output: ", y_ref)
    # print("fp8 gemm output shape: ", y_ref.shape)

    # Re-quantize x and w as int8 and run the Triton blockwise INT8 kernel
    x_te_dtype = TE_DType[torch.int8]
    w_te_dtype = TE_DType[torch.int8]
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )
    # Quantize x and w
    qx = x_quantizer.make_empty(x_shape, dtype=torch.int8, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)
    qw = w_quantizer.make_empty(w_shape, dtype=torch.int8, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)
    qx_data = (
        qx._columnwise_data.view(dtype=torch.int8)
        if x_columnwise
        else qx._rowwise_data.view(dtype=torch.int8)
    )
    qw_data = (
        qw._columnwise_data.view(dtype=torch.int8)
        if w_columnwise
        else qw._rowwise_data.view(dtype=torch.int8)
    )
    ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
    ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
    y, _ = w8a8_block_int8_matmul(
        qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128], output_dtype=out_dtype
    )
    # print("int8 gemm output: ", y)
    # print("int8 gemm output shape: ", y.shape)
    torch.testing.assert_close(y, bf16_out, atol=atol, rtol=rtol)
    torch.testing.assert_close(y_ref, bf16_out, atol=atol, rtol=rtol)
```
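For orientation, here is a minimal sketch of the symmetric per-(1, 128)-group int8 quantization that the test expects `Float8BlockQuantizer` to perform when its dtype is `kInt8`. This is an assumption for illustration; the real quantizer also maintains columnwise data and power-of-two scales:

```python
import torch

def quantize_int8_per_token_group(x: torch.Tensor, group: int = 128):
    """Symmetric absmax int8 quantization of each (1, group) block of a 2-D tensor."""
    m, k = x.shape
    xg = x.float().reshape(m, k // group, group)
    scale = (xg.abs().amax(dim=-1) / 127.0).clamp_min(1e-8)  # one scale per block
    q = (xg / scale[..., None]).round().clamp(-128, 127).to(torch.int8)
    return q.reshape(m, k), scale  # dequantize blockwise with q * scale

x = torch.randn(4, 256)
q, s = quantize_int8_per_token_group(x)
x_hat = (q.reshape(4, 2, 128).float() * s[..., None]).reshape(4, 256)
assert torch.allclose(x, x_hat, atol=x.abs().max().item() / 127)
```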
```python
def cublas_gemm_fp8_blockwise_case_bw_xgrad(
    dout_dtype,
    w_dtype,
    dx_dtype,
    M,
    K,
    N,
    noise_type,
    dout_magnitude,
    w_magnitude,
    accumulate,
    use_split_accumulator,
    is_dout_1d_scaled,
    is_w_1d_scaled,
    *,
    dout_columnwise: bool = False,
    w_columnwise: bool = True,
    use_bias: bool = False,
    use_gelu: bool = False,
    use_grad: bool = False,
    atol: float = 5e-1,
    rtol: float = 5e-1,
):
    if dout_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
        pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
    if not (is_dout_1d_scaled or is_w_1d_scaled):
        pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    dout_shape = (M, N)
    w_shape = (N, K)

    # generate random input and weight
    if noise_type == "uniform":
        dout = torch.rand(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude * 2 - dout_magnitude
        w = torch.rand(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude * 2 - w_magnitude
    elif noise_type == "normal":
        dout = torch.randn(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude
        w = torch.randn(w_shape, dtype=torch.bfloat16, device=device) * w_magnitude
    else:
        assert False

    bf16_dx = torch.matmul(dout, w)
    # print("bf16 gemm dx: ", bf16_dx)

    # Setup out tensor if accumulate is True
    if accumulate:
        dx = torch.randn((M, K), dtype=dx_dtype, device=device) * dout_magnitude
    else:
        dx = None

    assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

    # Set quantize_op and quantization parameters
    dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
    dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
    w_block_scaling_dim = 1 if is_w_1d_scaled else 2
    dout_te_dtype = TE_DType[dout_dtype]
    w_te_dtype = TE_DType[w_dtype]
    dout_quantizer = Float8BlockQuantizer(
        fp8_dtype=dout_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=dout_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )

    # Quantize dout and w
    qdout = dout_quantizer.make_empty(dout_shape, dtype=dout_dtype, device=device, requires_grad=False)
    qdout = dout_quantizer.update_quantized(dout, qdout)
    qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)

    if not use_bias:
        bias = None
    else:
        bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)

    # Reference GEMM
    ref_gemm = CuBLASRefBlockwiseGemm()
    scale_decoder = CuBLASScaleMunger()
    qdout_data = (
        qdout._columnwise_data.view(dtype=dout_dtype)
        if dout_columnwise
        else qdout._rowwise_data.view(dtype=dout_dtype)
    )
    qw_data = (
        qw._columnwise_data.view(dtype=w_dtype)
        if w_columnwise
        else qw._rowwise_data.view(dtype=w_dtype)
    )
    ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
    ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
    y_ref = ref_gemm.qgemm(
        qx=qdout_data,
        qw=qw_data,
        out_dtype=dx_dtype,
        demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(M, N), scales=ref_scales_dout, tile_shape=dout_quant_tile_shape
        ),
        demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(K, N), scales=ref_scales_w, tile_shape=w_quant_tile_shape
        ),
        quant_tile_shape_x=dout_quant_tile_shape,
        quant_tile_shape_w=w_quant_tile_shape,
        bias=bias,
        out=dx.clone() if accumulate else None,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )
    # print("fp8 gemm dx: ", y_ref)

    # Re-quantize dout and w as int8 and run the Triton blockwise INT8 kernel
    dout_te_dtype = TE_DType[torch.int8]
    w_te_dtype = TE_DType[torch.int8]
    dout_quantizer = Float8BlockQuantizer(
        fp8_dtype=dout_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=dout_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )
    # Quantize x and w
    qdout = dout_quantizer.make_empty(dout_shape, dtype=torch.int8, device=device, requires_grad=False)
    qdout = dout_quantizer.update_quantized(dout, qdout)
    qw = w_quantizer.make_empty(w_shape, dtype=torch.int8, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)
    qdout_data = (
        qdout._columnwise_data.view(dtype=torch.int8)
        if dout_columnwise
        else qdout._rowwise_data.view(dtype=torch.int8)
    )
    qw_data = (
        qw._columnwise_data.view(dtype=torch.int8)
        if w_columnwise
        else qw._rowwise_data.view(dtype=torch.int8)
    )
    ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
    ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
    y, _ = w8a8_block_int8_matmul(
        qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128], output_dtype=dx_dtype
    )
    # print("int8 gemm dx: ", y)
    torch.testing.assert_close(y_ref, bf16_dx, atol=atol, rtol=rtol)
    torch.testing.assert_close(y, bf16_dx, atol=atol, rtol=rtol)


def cublas_gemm_fp8_blockwise_case_bw_wgrad(
    dout_dtype,
    x_dtype,
    dw_dtype,
    M,
    K,
    N,
    noise_type,
    dout_magnitude,
    x_magnitude,
    accumulate,
    use_split_accumulator,
    is_dout_1d_scaled,
    is_x_1d_scaled,
    *,
    dout_columnwise: bool = True,
    x_columnwise: bool = True,
    use_bias: bool = False,
    use_gelu: bool = False,
    use_grad: bool = False,
    atol: float = 5e-1,
    rtol: float = 5e-1,
):
    if dout_dtype == torch.float8_e5m2 and x_dtype == torch.float8_e5m2:
        pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
    if not (is_dout_1d_scaled or is_x_1d_scaled):
        pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    dout_shape = (M, N)
    x_shape = (M, K)

    # generate random input and weight
    if noise_type == "uniform":
        dout = torch.rand(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude * 2 - dout_magnitude
        x = torch.rand(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude * 2 - x_magnitude
    elif noise_type == "normal":
        dout = torch.randn(dout_shape, dtype=torch.bfloat16, device=device) * dout_magnitude
        x = torch.randn(x_shape, dtype=torch.bfloat16, device=device) * x_magnitude
    else:
        assert False

    bf16_dw = torch.matmul(dout.t(), x)
    # print("bf16 gemm dw: ", bf16_dw)

    # Setup out tensor if accumulate is True
    if accumulate:
        dw = torch.randn((N, K), dtype=dw_dtype, device=device) * dout_magnitude
    else:
        dw = None

    assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

    # Set quantize_op and quantization parameters
    dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
    dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
    x_block_scaling_dim = 1 if is_x_1d_scaled else 2
    dout_te_dtype = TE_DType[dout_dtype]
    x_te_dtype = TE_DType[x_dtype]
    dout_quantizer = Float8BlockQuantizer(
        fp8_dtype=dout_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=dout_block_scaling_dim,
    )
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )

    # Quantize dout and w
    qdout = dout_quantizer.make_empty(dout_shape, dtype=dout_dtype, device=device, requires_grad=False)
    qdout = dout_quantizer.update_quantized(dout, qdout)
    qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)

    if not use_bias:
        bias = None
    else:
        bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)

    # Reference GEMM
    ref_gemm = CuBLASRefBlockwiseGemm()
    scale_decoder = CuBLASScaleMunger()
    qdout_data = (
        qdout._columnwise_data.view(dtype=dout_dtype)
        if dout_columnwise
        else qdout._rowwise_data.view(dtype=dout_dtype)
    )
    qx_data = (
        qx._columnwise_data.view(dtype=x_dtype)
        if x_columnwise
        else qx._rowwise_data.view(dtype=x_dtype)
    )
    ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
    ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
    y_ref = ref_gemm.qgemm(
        qx=qdout_data,
        qw=qx_data,
        out_dtype=dw_dtype,
        demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(N, M), scales=ref_scales_dout, tile_shape=dout_quant_tile_shape
        ),
        demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(K, M), scales=ref_scales_x, tile_shape=x_quant_tile_shape
        ),
        quant_tile_shape_x=dout_quant_tile_shape,
        quant_tile_shape_w=x_quant_tile_shape,
        bias=bias,
        out=dw.clone() if accumulate else None,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )
    # print("fp8 gemm dw: ", y_ref)

    # Re-quantize dout and x as int8 and run the Triton blockwise INT8 wgrad kernel
    dout_te_dtype = TE_DType[torch.int8]
    x_te_dtype = TE_DType[torch.int8]
    dout_quantizer = Float8BlockQuantizer(
        fp8_dtype=dout_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=dout_block_scaling_dim,
    )
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )
    # Quantize x and w
    qdout = dout_quantizer.make_empty(dout_shape, dtype=torch.int8, device=device, requires_grad=False)
    qdout = dout_quantizer.update_quantized(dout, qdout)
    qx = x_quantizer.make_empty(x_shape, dtype=torch.int8, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)
    qdout_data = (
        qdout._columnwise_data.view(dtype=torch.int8)
        if dout_columnwise
        else qdout._rowwise_data.view(dtype=torch.int8)
    )
    qx_data = (
        qx._columnwise_data.view(dtype=torch.int8)
        if x_columnwise
        else qx._rowwise_data.view(dtype=torch.int8)
    )
    ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
    ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
    # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
    # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
    y, _ = w8a8_block_int8_matmul_wgrad(
        qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128], output_dtype=dw_dtype
    )
    # print("int8 gemm dw: ", y)
    torch.testing.assert_close(y_ref, bf16_dw, atol=atol, rtol=rtol)
    torch.testing.assert_close(y, bf16_dw, atol=atol, rtol=rtol)
```
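The three cases correspond to the three GEMMs of a linear layer; stripped of quantization, their reference computations reduce to the following shapes:

```python
import torch

M, K, N = 128, 512, 256          # tokens, in_features, out_features
x = torch.randn(M, K)
w = torch.randn(N, K)
dout = torch.randn(M, N)

y = x @ w.t()                    # forward  (case_fw):       (M, K) x (K, N) -> (M, N)
dx = dout @ w                    # dgrad    (case_bw_xgrad): (M, N) x (N, K) -> (M, K)
dw = dout.t() @ x                # wgrad    (case_bw_wgrad): (N, M) x (M, K) -> (N, K)
assert y.shape == (M, N) and dx.shape == (M, K) and dw.shape == (N, K)
```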
```python
def test_cublas_gemm_fp8_blockwise_fw(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type,
    x_magnitude, w_magnitude, accumulate, use_split_accumulator,
    is_x_1d_scaled, is_w_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case_fw(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type,
        x_magnitude, w_magnitude, accumulate, use_split_accumulator,
        is_x_1d_scaled, is_w_1d_scaled,
    )


def test_cublas_gemm_fp8_blockwise_bw_xgrad(
    dout_dtype, w_dtype, dx_dtype, M, K, N, noise_type,
    dout_magnitude, w_magnitude, accumulate, use_split_accumulator,
    is_dout_1d_scaled, is_w_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case_bw_xgrad(
        dout_dtype, w_dtype, dx_dtype, M, K, N, noise_type,
        dout_magnitude, w_magnitude, accumulate, use_split_accumulator,
        is_dout_1d_scaled, is_w_1d_scaled,
    )


def test_cublas_gemm_fp8_blockwise_bw_wgrad(
    dout_dtype, x_dtype, dw_dtype, M, K, N, noise_type,
    dout_magnitude, x_magnitude, accumulate, use_split_accumulator,
    is_dout_1d_scaled, is_x_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case_bw_wgrad(
        dout_dtype, x_dtype, dw_dtype, M, K, N, noise_type,
        dout_magnitude, x_magnitude, accumulate, use_split_accumulator,
        is_dout_1d_scaled, is_x_1d_scaled,
    )


if __name__ == "__main__":
    test_cublas_gemm_fp8_blockwise_fw(
        x_dtype=torch.float8_e4m3fn,   # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        w_dtype=torch.float8_e4m3fn,   # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        out_dtype=torch.bfloat16,
        M=128,   # batch_size * seq_len
        K=512,   # in_feature
        N=256,   # out_feature
        noise_type="normal",
        x_magnitude=1e-1,
        w_magnitude=1,
        accumulate=False,
        use_split_accumulator=True,
        is_x_1d_scaled=True,
        is_w_1d_scaled=False,
    )
    test_cublas_gemm_fp8_blockwise_bw_xgrad(
        dout_dtype=torch.float8_e4m3fn,   # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        w_dtype=torch.float8_e4m3fn,      # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        dx_dtype=torch.bfloat16,
        M=128,   # batch_size * seq_len
        K=512,   # in_feature
        N=256,   # out_feature
        noise_type="normal",
        dout_magnitude=1e-1,
        w_magnitude=1,
        accumulate=False,
        use_split_accumulator=True,
        is_dout_1d_scaled=True,
        is_w_1d_scaled=False,
    )
    test_cublas_gemm_fp8_blockwise_bw_wgrad(
        dout_dtype=torch.float8_e4m3fn,   # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        x_dtype=torch.float8_e4m3fn,      # torch.float8_e4m3fnuz if te.e4m3 uses fnuz
        dw_dtype=torch.bfloat16,
        M=128,   # batch_size * seq_len
        K=512,   # in_feature
        N=256,   # out_feature
        noise_type="normal",
        dout_magnitude=1e-1,
        x_magnitude=1,
        accumulate=False,
        use_split_accumulator=True,
        is_dout_1d_scaled=True,
        is_x_1d_scaled=True,
    )
```
tests/pytorch/test_int8_blockwise_layers.py (new file, 0 → 100644)

```python
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple
import math
import os
import pathlib

import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
    BlockwiseQuantizerReference,
    QuantizeResult,
)
from test_float8_current_scaling_exact import (
    TestFP8RecipeLinearBase,
    TestFP8RecipeLayerNormLinearBase,
)
import logging

# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
    TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)

recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()


class GetRecipes:
    @staticmethod
    def none():
        return None

    @staticmethod
    def fp8_blockwise():
        # return default configs
        return Float8BlockScaling()


# FP8 per tensor current scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=False,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        assert recipe1 == GetRecipes.none, "Only None recipe is supported for recipe1"

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            dgrad_error=1,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_layernorm_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=False,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR,
            recipe2,
            (batch_size, hidden_size, out_size),
            dtype,
            use_bias,
            "LayerNorm",
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.9,
            ln_out_error=0.5,
            dgrad_error=1.5,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )
```
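A sketch of pointing these exactness tests at a golden tensor dump directory (the directory layout is whatever `_check_golden_tensor_dumps` in the base classes expects; the path here is illustrative):

```bash
NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR=/data/tensor_dumps \
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s tests/pytorch/test_int8_blockwise_layers.py
```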
transformer_engine/pytorch/constants.py

```diff
@@ -35,6 +35,7 @@ TE_DType_To_Torch = {
     tex.DType.kByte: torch.uint8,
     tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
     tex.DType.kFloat8E5M2: torch.float8_e5m2,
+    tex.DType.kInt8: torch.int8,
     tex.DType.kInt32: torch.int32,
     tex.DType.kFloat32: torch.float32,
     tex.DType.kFloat16: torch.half,
```
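The new entry lets int8-backed quantized tensors round-trip through the TE-dtype-to-torch mapping; a one-line check (sketch):

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch

assert TE_DType_To_Torch[tex.DType.kInt8] is torch.int8
```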
transformer_engine/pytorch/cpp_extensions/gemm.py

```diff
@@ -10,11 +10,13 @@ import torch
 import transformer_engine_torch as tex
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 
 __all__ = [
     "general_gemm",
     "general_grouped_gemm",
@@ -54,6 +56,67 @@ def general_gemm(
     #         + "a valid `ub` communicator object."
     #     )
+    if int8_simulation_fp8 and (
+        isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)
+    ):
+        assert not gelu, "GELU not supported with int8 simulation"
+        assert gelu_in is None, "GELU input not supported with int8 simulation"
+        assert bias is None, "Bias not supported with int8 simulation"
+        assert not accumulate, "Accumulation not supported with int8 simulation"
+        assert ub is None, "User buffer not supported with int8 simulation"
+        assert ub_type is None, "User buffer type not supported with int8 simulation"
+        assert extra_output is None, "Extra output not supported with int8 simulation"
+        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
+        if layout == "TN":
+            qx_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._rowwise_data.view(dtype=torch.int8)
+            ref_scales_x = B._rowwise_scale_inv
+            ref_scales_w = A._rowwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128], output_dtype=out_dtype
+            )
+            return y, None, None, None
+        elif layout == "NN":
+            qdout_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._columnwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128], output_dtype=out_dtype
+            )
+            return y, None, None, None
+        elif layout == "NT":
+            qdout_data = B._columnwise_data.view(dtype=torch.int8)
+            qx_data = A._columnwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._columnwise_scale_inv
+            ref_scales_x = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul_wgrad(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128], output_dtype=out_dtype
+            )
+            return y, None, None, None
+        else:
+            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
         if ub_type == tex.CommOverlapType.RS:
```
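Note that `int8_simulation_fp8` is a module-level flag captured from the environment when `gemm.py` (and likewise `fp8.py` and `float8_blockwise_tensor.py` below) is first imported, so the switch has to be exported before Transformer Engine is imported; a minimal sketch:

```python
import os

# Set the switch before any transformer_engine.pytorch import: each module
# reads NVTE_INT8_SIM_FP8 exactly once, at import time.
os.environ["NVTE_INT8_SIM_FP8"] = "1"

import transformer_engine.pytorch  # noqa: E402  (the int8-simulation GEMM path is now active)
```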
transformer_engine/pytorch/fp8.py

```diff
@@ -27,16 +27,18 @@ from .constants import dist_group_type
 from .utils import get_device_compute_capability
 from .jit import jit_fuser
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 
 __all__ = ["fp8_autocast", "fp8_model_init"]
 
 if IS_HIP_EXTENSION:
     from transformer_engine.pytorch.utils import is_K100_AI, is_BW
 
 
 def check_fp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
     if IS_HIP_EXTENSION:
         if get_device_compute_capability() == (9, 4):
             return True, ""
+        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
+            return True, "DCU turn on fp8 simulation with int8"
         else:
             return False, "DCU not support fp8 for now"
     else:
@@ -61,7 +63,10 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
+    if IS_HIP_EXTENSION:
+        if is_K100_AI() or is_BW():
+            return True, ""
+        else:
+            return False, "DCU not support block_scaling fp8 for now"
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
```
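With the flag on, K100_AI/BW DCUs report FP8 support through the simulation path; a quick probe (output is hardware-dependent, shown here as an assumption for a K100_AI with `NVTE_INT8_SIM_FP8=1`):

```python
from transformer_engine.pytorch.fp8 import check_fp8_support, check_fp8_block_scaling_support

ok, reason = check_fp8_support()
print(ok, reason)  # e.g. True "DCU turn on fp8 simulation with int8"
print(check_fp8_block_scaling_support())  # e.g. (True, "") per the new IS_HIP_EXTENSION branch
```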
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

```diff
@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Iterable
 import math
 import torch
 import transformer_engine_torch as tex
+import os
 from transformer_engine_torch import DType as TE_DType
 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -17,6 +17,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple
 aten = torch.ops.aten
 
+int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 
 class Float8BlockQuantizer(Quantizer):
     """Builder class for tensors quantized with current scaling using
@@ -44,7 +45,7 @@ class Float8BlockQuantizer(Quantizer):
         block_scaling_dim: int = 2,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.dtype = fp8_dtype
+        self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
         self.block_len = 128
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon
```
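In other words, with `NVTE_INT8_SIM_FP8=1` every `Float8BlockQuantizer` silently becomes an int8 quantizer regardless of the `fp8_dtype` the caller requested; a sketch of the observable effect (holds only when the variable was set before import):

```python
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

q = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    amax_epsilon=0.0,
    force_pow_2_scales=True,
    block_scaling_dim=1,
)
assert q.dtype == tex.DType.kInt8  # the requested E4M3 dtype was overridden
```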
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (new file, 0 → 100644)

```python
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_helper
import functools
import logging

logger = logging.getLogger(__name__)

device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")

tuning_full_space = False
# tuning_full_space = True


def get_full_tuning_space():
    configs = []
    if not tuning_full_space:
        return configs

    block_m_range = [16, 32, 64]
    block_n_range = [16, 32, 64, 128]
    block_k_range = [32, 64, 128]
    num_warps_range = [4, 8]
    group_m_range = [2, 4, 8]
    # For now we see better perf with num_stages=0 for all gemm configs we care
    # But keep this explicit so that we do not forget we may need to set it to
    # other values in the future
    num_stage_range = [0, 1, 2]

    for block_m in block_m_range:
        for block_n in block_n_range:
            for block_k in block_k_range:
                for num_warps in num_warps_range:
                    for group_m in group_m_range:
                        for num_stages in num_stage_range:
                            configs.append(
                                triton.Config(
                                    {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n,
                                     'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m},
                                    num_stages=num_stages, num_warps=num_warps,
                                    enable_mmacfuse=2,
                                )
                            )
    return configs


@triton.autotune(
    configs=get_full_tuning_space() if tuning_full_space else [
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack': 2}, num_stages=2, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 2,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8,}, num_stages=1, num_warps=4, enable_mmacfuse=2),
    ],
    key=['M', 'N', 'K'],
    # reset_to_zero=['c_ptr']
)
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization,
    and store the result in output tensor `C`.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bsn = pid_n * BLOCK_SIZE_N // group_n
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # a_ptrs = A + (offs_am[:, None] * stride_am)
    # b_ptrs = B + (offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    # offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
```
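The inner loop dequantizes each int8 K-tile product before accumulating. In symbols, with $s^{A}_{m,\kappa}$ the per-token scale of row $m$ in K-group $\kappa$ and $s^{B}_{\nu,\kappa}$ the scale of weight block $(\nu,\kappa)$, and assuming `BLOCK_SIZE_K` divides the quantization group `group_k` so each tile sees a single scale, the kernel computes (with `B` indexed as the kernel reads it, i.e. the `(N, K)` tensor transposed):

```math
C_{mn} = \sum_{\kappa = 0}^{\lceil K / \mathrm{group}_k \rceil - 1}
  s^{A}_{m,\kappa}\; s^{B}_{\lfloor n/\mathrm{group}_n \rfloor,\,\kappa}
  \sum_{k \in \kappa} A_{mk}\, B_{kn}
```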
```python
@functools.lru_cache
def get_w8a8_block_int8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the w8a8 block fp8 kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """
    # First look up if an optimized configuration is available in the configs
    # directory
    device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n},{block_k}].json"  # noqa: E501

    config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                "Using configuration from %s for W8A8 Block INT8 kernel.",
                config_file_path,
            )
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        (
            "Using default W8A8 Block INT8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s"
        ),
        config_file_path,
    )
    return None


def w8a8_block_int8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
    best_config: Optional[dict] = None,
) -> torch.Tensor:
    """Matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    # print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
    # assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[0]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    # assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
    # if configs:
    #     # If an optimal configuration map has been found, look up the
    #     # optimal config
    #     config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    # else:
    if best_config:
        config = best_config
    else:
        # print("best config has not found!")
        # config = {
        #     "BLOCK_SIZE_M": 32,  # 64
        #     "BLOCK_SIZE_N": block_size[0],
        #     "BLOCK_SIZE_K": block_size[1],
        #     "GROUP_SIZE_M": 32,
        #     "num_warps": 4,
        #     "num_stages": 3,
        # }
        # Default config
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
        # print("block_size[0]:{},block_size[1]:{}".format(block_size[0], block_size[1]))
        if M <= 64:
            config = {
                "BLOCK_SIZE_M": 16,  # 64
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        elif M < 128:
            config = {
                "BLOCK_SIZE_M": 32,  # 64
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        elif M <= 256:
            config = {
                "BLOCK_SIZE_M": 64,  # 64
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        else:
            config = {
                "BLOCK_SIZE_M": 64,  # 64
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 8,
                "num_warps": 8,
                "num_stages": 0,
            }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)

    # print("config:", config)
    # print(f"zhenggf, A.shape:{A.shape}, B.shape:{B.shape}")
    # print(f"zhenggf, A.stride(-2):{A.stride(-2)}, A.stride(-1):{A.stride(-1)}, B.stride(1):{B.stride(1)}, B.stride(0):{B.stride(0)}")
    # print(f"zhenggf, As.stride(-2):{As.stride(-2)}, As.stride(-1):{As.stride(-1)}, Bs.stride(1):{Bs.stride(1)}, Bs.stride(0):{Bs.stride(0)}")
    # As = As.permute(1, 0).contiguous()
    _w8a8_block_int8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        # As.stride(-2),
        # As.stride(-1),
        As.stride(1),
        As.stride(0),
        Bs.stride(1),
        Bs.stride(0),
        # **config,
    )
    config = _w8a8_block_int8_matmul.best_config
    return C, config
```
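A sketch of calling the wrapper directly, with hypothetical shapes chosen to satisfy its assertions (M=256 tokens, K=512, N=128; the scale layouts follow the wrapper's `As.stride(1)/As.stride(0)` usage, and running it requires a GPU with this Triton backend):

```python
import torch
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul

M, K, N = 256, 512, 128
A = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")   # int8 activation
B = torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda")   # int8 weight
As = torch.rand(K // 128, M, device="cuda")          # per-(1,128) token-group scales, K-major
Bs = torch.rand(N // 128, K // 128, device="cuda")   # per-(128,128) weight-block scales

C, used_config = w8a8_block_int8_matmul(A, B, As, Bs, [128, 128], output_dtype=torch.bfloat16)
assert C.shape == (M, N)
```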
```python
def apply_w8a8_block_int8_linear_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[dict] = None,
):
    q_input, x_scale, weight, weight_scale = _int8_gemm_helper(
        m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size
    )
    print(f"zhenggf, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
    x_scale = x_scale.permute(1, 0).contiguous()
    output, config = w8a8_block_int8_matmul(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config
    )
    if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
        print("Triton accuracy check FAILED!!!")
    else:
        print("Triton accuracy check passed")
    # unit test end

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for it in range(1000):
            output, _ = w8a8_block_int8_matmul(
                q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config
            )
    torch.cuda.synchronize()
    start_time_ = time.time()  # start timing
    g.replay()
    torch.cuda.synchronize()
    end_time_ = time.time()  # stop timing
    elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # elapsed wall time in ms
    print("_time:{} ms\n".format(elapsed_time))

    quantiles = [0.5, 0.2, 0.8]
    gpu_costtime = triton.testing.do_bench(
        lambda: w8a8_block_int8_matmul(
            q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config
        ),
        quantiles=None,
        return_mode="mean",
    ) * 1000

    if bias is not None:
        output = output + bias

    return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config


def get_triton_cache(file_path, n, k, block_n, block_k):
    # Returns the parsed JSON cache file as a dictionary.
    # Read the specified file first; if it does not exist, create it with an empty cache.
    cache_json_file = file_path
    if os.path.exists(file_path):
        # try:
        with open(cache_json_file, 'r') as file:
            cachedata = json.load(file)
    else:
        cachedata = {}
        # Write the empty data to a new JSON file
        with open(file_path, 'w') as file:
            json.dump(cachedata, file)

    # Parse every cached entry into key:config form: [M_N_K]:[config]
    configs_dict = {}
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_key = f"{sub_key}_{key}"
            configs_value = {
                'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
                'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
                'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
                'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
                'num_stages': int(sub_value['num_stages']),
                'num_warps': int(sub_value['num_warps']),
                # 'kpack': int(sub_value['kpack']),
                'enable_mmacfuse': int(2),
            }
            configs_dict[configs_key] = configs_value
    return configs_dict


def getspec_config(configs_dict, m, n, k, block_n, block_k):
    if f"{m}_{n}_{k}_block[{block_n},{block_k}]" in configs_dict:
        return configs_dict[f"{m}_{n}_{k}_block[{block_n},{block_k}]"]
    else:
        return None


# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
    """This function performs matrix multiplication with block-wise quantization using native torch.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [
        [
            B[
                j * block_n : min((j + 1) * block_n, N),
                i * block_k : min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def main():
    m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
    m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
    m3 = [2 << (item) for item in range(7, 13)]
    m_list = m1 + m2 + m3
    n_list = [576, 2048, 7168, 256, 7168, 1536, 1536, 2304, 7168]
    k_list = [7168, 512, 1024, 7168, 128, 7168, 1536, 7168, 1152]
    m_list = [8192]
    n_list = [7168]
    k_list = [1152]
    block_size = [128, 128]
    out_dtype = torch.bfloat16

    _n = []
    _k = []
    _m = []
    _configs_block_m = []
    _configs_block_n = []
    _configs_block_k = []
    _configs_block_group_m = []
    _configs_block_num_warps = []
    _configs_block_num_stages = []
    _configs_kpack = []
    cost_times = []
    gpu_costtimes = []
    device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")

    for i in range(0, len(k_list), 1):
        for m in m_list:
            print("m:{} n:{} k:{} ".format(m, n_list[i], k_list[i]))
            best_config = []
            output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(
                m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config
            )
            cost_times.append(elapsed_time)
            gpu_costtimes.append(gpu_costtime)
            _n.append(n_list[i])
            _k.append(k_list[i])
            _m.append(m)
            print(f"zhenggf, {config}")
            print(f"zhenggf, {config.kwargs}")
            _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
            _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
            _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
            _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
            _configs_block_num_warps.append(config.num_warps)
            _configs_block_num_stages.append(config.num_stages)
            # _configs_kpack.append(config['kpack'])

    # Build a DataFrame from these lists
    df = pd.DataFrame({
        'm': _m,
        'n': _n,
        'k': _k,
        'linear gemm quant op time': cost_times,
        'GPU op time': gpu_costtimes,
        'BLOCK_SIZE_M': _configs_block_m,
        'BLOCK_SIZE_N': _configs_block_n,
        'BLOCK_SIZE_K': _configs_block_k,
        'GROUP_SIZE_M': _configs_block_group_m,
        'num_warps': _configs_block_num_warps,
        'num_stages': _configs_block_num_stages,
        # 'kpack': _configs_kpack,
    })

    # Write the DataFrame to an Excel file
    df.to_excel('gemmoutput.xlsx', index=False)
    print("Table saved to gemmoutput.xlsx.")


if __name__ == "__main__":
    main()
```
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
0 → 100644
View file @
1b303e91
import
torch
import
time
from
typing
import
Optional
,
Type
,
Any
,
Dict
,
List
,
Tuple
import
pandas
as
pd
import
os
import
json
import
triton
import
triton.language
as
tl
import
pandas
as
pd
from
transformer_engine.pytorch.triton.per_token_group_quant
import
_int8_gemm_helper_b
import
functools
import
logging
logger
=
logging
.
getLogger
(
__name__
)
device_name
=
torch
.
cuda
.
get_device_properties
(
'cuda'
).
name
.
replace
(
" "
,
"_"
)
tuning_full_space
=
False
# tuning_full_space = True
def
get_full_tuning_space
():
configs
=
[]
if
not
tuning_full_space
:
return
configs
block_m_range
=
[
16
,
32
,
64
]
block_n_range
=
[
16
,
32
,
64
,
128
]
block_k_range
=
[
32
,
64
,
128
]
num_warps_range
=
[
4
,
8
]
group_m_range
=
[
2
,
4
,
8
]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range
=
[
0
,
1
,
2
]
for
block_m
in
block_m_range
:
for
block_n
in
block_n_range
:
for
block_k
in
block_k_range
:
for
num_warps
in
num_warps_range
:
for
group_m
in
group_m_range
:
for
num_stages
in
num_stage_range
:
configs
.
append
(
triton
.
Config
({
'BLOCK_SIZE_M'
:
block_m
,
'BLOCK_SIZE_N'
:
block_n
,
'BLOCK_SIZE_K'
:
block_k
,
'GROUP_SIZE_M'
:
group_m
},
num_stages
=
num_stages
,
num_warps
=
num_warps
,
enable_mmacfuse
=
2
))
return
configs
@
triton
.
autotune
(
configs
=
get_full_tuning_space
()
if
tuning_full_space
else
[
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 2, 'kpack':2}, num_stages=2, num_warps=8),
triton
.
Config
({
'BLOCK_SIZE_M'
:
16
,
'BLOCK_SIZE_N'
:
64
,
'BLOCK_SIZE_K'
:
32
,
'GROUP_SIZE_M'
:
2
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
triton
.
Config
({
'BLOCK_SIZE_M'
:
64
,
'BLOCK_SIZE_N'
:
128
,
'BLOCK_SIZE_K'
:
128
,
'GROUP_SIZE_M'
:
8
,},
num_stages
=
1
,
num_warps
=
4
,
enable_mmacfuse
=
2
),
],
key
=
[
'M'
,
'N'
,
'K'
],
# reset_to_zero=['c_ptr']
)
@triton.jit
def _w8a8_block_int8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization,
    and store the result in output tensor `C`.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # offs_bsn = pid_n * BLOCK_SIZE_N // group_n
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    As_ptrs = As + offs_am * stride_As_m
    # offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
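# For intuition: per K-block, the kernel rescales the int8 partial product by
# the per-row activation scale `a_s` and the per-column weight scale `b_s`
# before accumulating in float32. A minimal reference for one output tile,
# assuming int8 tiles `a_blk` (BLOCK_M, BLOCK_K) and `b_blk` (BLOCK_K, BLOCK_N)
# with float32 scale vectors `a_s` (BLOCK_M,) and `b_s` (BLOCK_N,):
#
#     acc += (a_blk.float() @ b_blk.float()) * a_s[:, None] * b_s[None, :]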
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int,
                                block_k: int) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the w8a8 block int8 kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the w8a8 block int8 kernel. To evaluate
    the kernel on a given batch size bs, the closest batch size in the grid
    should be picked and the associated configuration chosen to invoke the
    kernel.
    """
    # First look up if an optimized configuration is available in the configs
    # directory
    device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n},{block_k}].json"  # noqa: E501

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info(
                "Using configuration from %s for W8A8 Block INT8 kernel.",
                config_file_path,
            )
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        ("Using default W8A8 Block INT8 kernel config. Performance might "
         "be sub-optimal! Config file not found at %s"),
        config_file_path,
    )
    return None
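# A config file matching the lookup above maps batch sizes to kernel
# parameters. A minimal hypothetical example of the expected layout for
# "N=7168,K=1152,device_name=...,dtype=int8_w8a8,block_shape=[128,128].json"
# (keys and values are illustrative, not tuned results):
#
#     {
#         "64":   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#                  "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 0},
#         "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#                  "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 0}
#     }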
def w8a8_block_int8_matmul_wgrad(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
    best_config: Optional[dict] = None
) -> Tuple[torch.Tensor, Any]:
    """Matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        The result of the matmul and the kernel config that was used.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    # print(f"A.shape[:-1] : {A.shape[:-1]}, As.shape[:-1]: {As.shape[:-1]}")
    # assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[0]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    # assert triton.cdiv(N, block_n) == Bs.shape[0]
    # assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
    # if configs:
    #     # If an optimal configuration map has been found, look up the
    #     # optimal config.
    #     config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    if best_config:
        config = best_config
    else:
        # print("best config not found!")
        # Default config.
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1].
        if M <= 64:
            config = {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        elif M < 128:
            config = {
                "BLOCK_SIZE_M": 32,
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        elif M <= 256:
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 2,
                "num_warps": 4,
                "num_stages": 0,
            }
        else:
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": block_size[0],
                "BLOCK_SIZE_K": block_size[1],
                "GROUP_SIZE_M": 8,
                "num_warps": 8,
                "num_stages": 0,
            }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)

    # print("config:", config)
    # print(f"A.shape:{A.shape}, B.shape:{B.shape}")
    # print(f"A.stride(-2):{A.stride(-2)}, A.stride(-1):{A.stride(-1)}, B.stride(1):{B.stride(1)}, B.stride(0):{B.stride(0)}")
    # As = As.permute(1, 0).contiguous()
    _w8a8_block_int8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        # As.stride(-2),
        # As.stride(-1),
        As.stride(1),
        As.stride(0),
        Bs.stride(-2),
        Bs.stride(-1),
        # Bs.stride(1),
        # Bs.stride(0),
        # **config,
    )
    config = _w8a8_block_int8_matmul.best_config
    return C, config
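# Minimal usage sketch (illustrative shapes; note the scale tensors are
# transposed to the layout this wgrad wrapper indexes, as done in
# apply_w8a8_block_int8_linear_helper below):
#
#     q_x, x_s, *_ = per_token_group_quant_int8(x.view(-1, x.shape[-1]), 128)
#     out, cfg = w8a8_block_int8_matmul_wgrad(
#         q_x, w_int8, x_s.permute(1, 0).contiguous(),
#         w_s.permute(1, 0).contiguous(), [128, 128],
#         output_dtype=torch.bfloat16)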
def apply_w8a8_block_int8_linear_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[dict] = None
):
    q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(
        m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
    print(f"q_input: {q_input.shape}, x_scale: {x_scale.shape}, "
          f"weight: {weight.shape}, weight_scale: {weight_scale.shape}")
    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale,
                                                 weight_scale, block_size)
    x_scale = x_scale.permute(1, 0).contiguous()
    weight_scale = weight_scale.permute(1, 0).contiguous()
    print(f"transposed before passing to the triton kernel, q_input: {q_input.shape}, "
          f"x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
    output, config = w8a8_block_int8_matmul_wgrad(
        q_input, weight, x_scale, weight_scale, block_size,
        output_dtype=out_dtype, best_config=best_config)
    if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
        print("Triton accuracy check FAILED!")
    else:
        print("Triton accuracy check passed.")
    # unit test end

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for it in range(1000):
            output, _ = w8a8_block_int8_matmul_wgrad(
                q_input, weight, x_scale, weight_scale, block_size,
                output_dtype=out_dtype, best_config=best_config)
    torch.cuda.synchronize()
    start_time_ = time.time()  # start timing
    g.replay()
    torch.cuda.synchronize()
    end_time_ = time.time()  # stop timing
    # (end - start) * 1000 ms over 1000 graphed calls == average us per call
    elapsed_time = round((end_time_ - start_time_) * 1000, 7)
    print("_time:{} us\n".format(elapsed_time))

    quantiles = [0.5, 0.2, 0.8]
    gpu_costtime = triton.testing.do_bench(
        lambda: w8a8_block_int8_matmul_wgrad(
            q_input, weight, x_scale, weight_scale, block_size,
            output_dtype=out_dtype, best_config=best_config),
        quantiles=None, return_mode="mean") * 1000
    if bias is not None:
        output = output + bias
    return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config
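# The timing idiom above records 1000 launches into one CUDA graph and replays
# it once, so launch overhead is paid a single time and
# (end - start) * 1000 over 1000 calls reads directly as microseconds per
# call. The same pattern for an arbitrary `fn`, as a sketch:
#
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g):
#         for _ in range(iters):
#             fn()
#     torch.cuda.synchronize()
#     t0 = time.time(); g.replay(); torch.cuda.synchronize()
#     us_per_call = (time.time() - t0) * 1e6 / iters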
def get_triton_cache(file_path, n, k, block_n, block_k):
    # Return the saved JSON cache file as a dict.
    # Read the specified file first; if the path does not exist, create it
    # and start from an empty cache.
    cache_json_file = file_path
    if os.path.exists(file_path):
        # try:
        with open(cache_json_file, 'r') as file:
            cachedata = json.load(file)
    else:
        cachedata = {}
        # Write the empty data to a new JSON file
        with open(file_path, 'w') as file:
            json.dump(cachedata, file)
    # Parse every cache entry into key: config form: [M_N_K]: [config]
    configs_dict = {}
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_key = f"{sub_key}_{key}"
            configs_value = {
                'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]),
                'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]),
                'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]),
                'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]),
                'num_stages': int(sub_value['num_stages']),
                'num_warps': int(sub_value['num_warps']),
                # 'kpack': int(sub_value['kpack']),
                'enable_mmacfuse': int(2),
            }
            configs_dict[configs_key] = configs_value
    return configs_dict
def getspec_config(configs_dict, m, n, k, block_n, block_k):
    if f"{m}_{n}_{k}_block[{block_n},{block_k}]" in configs_dict:
        return configs_dict[f"{m}_{n}_{k}_block[{block_n},{block_k}]"]
    else:
        return None
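# Taken together, get_triton_cache and getspec_config imply a cache file laid
# out roughly as below (hypothetical values), where the outer key is the
# "{n}_{k}_block[{block_n},{block_k}]" suffix and the inner key is m:
#
#     {
#         "7168_1152_block[128,128]": {
#             "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128,
#                      "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8,
#                      "num_warps": 8, "num_stages": 0}
#         }
#     }
#
#     cfg = getspec_config(get_triton_cache("cache.json", 7168, 1152, 128, 128),
#                          8192, 7168, 1152, 128, 128)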
# For test
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
    """This function performs matrix multiplication with block-wise
    quantization using native torch.
    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    # assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [A[:, i * block_k: min((i + 1) * block_k, K)] for i in range(k_tiles)]
    B_tiles = [
        [
            B[
                j * block_n: min((j + 1) * block_n, N),
                i * block_k: min((i + 1) * block_k, K),
            ]
            for i in range(k_tiles)
        ]
        for j in range(n_tiles)
    ]
    C_tiles = [C[:, j * block_n: min((j + 1) * block_n, N)] for j in range(n_tiles)]
    As_tiles = [As[:, i: i + 1] for i in range(k_tiles)]
    Bs_tiles = [Bs[:, i: i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs_tiles[i].t()[:, j * block_n: min((j + 1) * block_n, N)]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C
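# Quick sanity sketch for the reference above (illustrative shapes): with
# block_size = [128, 128], A int8 of shape (M, K) carries one scale per
# (row, K-block) and B int8 of shape (N, K) carries one scale per
# (output-channel, K-block) here, so each K-block partial product is rescaled
# by As[:, i] * Bs[:, i] before accumulation.
#
#     A = torch.randint(-128, 128, (4, 256), dtype=torch.int8).cuda()
#     B = torch.randint(-128, 128, (8, 256), dtype=torch.int8).cuda()
#     As = torch.rand(4, 2).cuda()   # K // 128 = 2 blocks
#     Bs = torch.rand(8, 2).cuda()
#     ref = native_w8a8_block_int8_matmul(A, B, As, Bs, [128, 128])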
def main():
    m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
    m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
    m3 = [2 << (item) for item in range(7, 13)]
    m_list = m1 + m2 + m3
    n_list = [576, 2048, 7168, 256, 7168, 1536, 1536, 2304, 7168]
    k_list = [7168, 512, 1024, 7168, 128, 7168, 1536, 7168, 1152]
    m_list = [8192]
    n_list = [7168]
    k_list = [1152]
    block_size = [128, 128]
    out_dtype = torch.bfloat16
    _n = []
    _k = []
    _m = []
    _configs_block_m = []
    _configs_block_n = []
    _configs_block_k = []
    _configs_block_group_m = []
    _configs_block_num_warps = []
    _configs_block_num_stages = []
    _configs_kpack = []
    cost_times = []
    gpu_costtimes = []
    device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
    for i in range(0, len(k_list), 1):
        for m in m_list:
            print("m:{} n:{} k:{} ".format(m, n_list[i], k_list[i]))
            best_config = []
            output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(
                m=m, n=n_list[i], k=k_list[i], block_size=block_size,
                out_dtype=out_dtype, best_config=best_config)
            cost_times.append(elapsed_time)
            gpu_costtimes.append(gpu_costtime)
            _n.append(n_list[i])
            _k.append(k_list[i])
            _m.append(m)
            print(f"config: {config}")
            print(f"config.kwargs: {config.kwargs}")
            _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
            _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
            _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
            _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
            _configs_block_num_warps.append(config.num_warps)
            _configs_block_num_stages.append(config.num_stages)
            # _configs_kpack.append(config['kpack'])
    # Build a DataFrame from these lists
    df = pd.DataFrame({
        'm': _m,
        'n': _n,
        'k': _k,
        'linear gemm quant op time (us)': cost_times,
        'GPU op time (us)': gpu_costtimes,
        'BLOCK_SIZE_M': _configs_block_m,
        'BLOCK_SIZE_N': _configs_block_n,
        'BLOCK_SIZE_K': _configs_block_k,
        'GROUP_SIZE_M': _configs_block_group_m,
        'num_warps': _configs_block_num_warps,
        'num_stages': _configs_block_num_stages,
        # 'kpack': _configs_kpack
    })
    # Write the DataFrame to an Excel file
    df.to_excel('gemmoutput.xlsx', index=False)
    print("Results saved to gemmoutput.xlsx.")


if __name__ == "__main__":
    main()
transformer_engine/pytorch/triton/per_token_group_quant.py
0 → 100644
View file @ 1b303e91
import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
import logging
import math


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
@triton.jit
def _per_token_group_quant_int8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid division by zero
    eps,
    # Information for int8
    int8_min,
    int8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform
    per-token-group quantization on a tensor.
    This function converts the tensor values into int8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / int8_max
    y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
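# A plain-PyTorch reference for what the kernel above computes, one float32
# scale per contiguous group of `group_size` elements along the last dim.
# A sketch for checking, assuming x.shape[-1] % group_size == 0; note the
# rounding behavior may differ slightly from the Triton float-to-int8 cast.
def _reference_per_token_group_quant_int8(x: torch.Tensor, group_size: int,
                                          eps: float = 1e-10):
    g = x.float().reshape(-1, group_size)  # one row per group
    s = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / 127.0
    q = torch.round(g / s).clamp(-128, 127).to(torch.int8)
    return q.reshape(x.shape), s.reshape(*x.shape[:-1], -1)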
def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.int8`
            is supported for now.

    Returns:
        The quantized tensor, the scaling factor for quantization, and the
        launch parameters (BLOCK, num_warps, num_stages, M).
    """
    assert (x.shape[-1] % group_size == 0), \
        "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_max = iinfo.max
    int8_min = iinfo.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )
    BLOCK = triton.next_power_of_2(N)  # N is block_size[1]
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_int8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        int8_min=int8_min,
        int8_max=int8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s, BLOCK, num_warps, num_stages, M
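# Usage sketch (illustrative shapes): quantize a bf16 activation in groups of
# 128 along the hidden dim; x_q is int8, x_s holds one float32 scale per group.
#
#     x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
#     x_q, x_s, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(x, 128)
#     # x_q: (4, 256) int8, x_s: (4, 2) float32, M == 8 groups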
def _int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None
):
    # Build test inputs: per-token-group activation scales and
    # per-(n-block, k-block) weight scales.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = (torch.randn(
        (math.ceil(n / block_size[0]), math.ceil(k / block_size[1])),
        device=device, dtype=torch.float32))
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{},weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale
def _int8_gemm_helper_b(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None
):
    # Build test inputs: per-token-group activation scales and
    # per-output-channel, per-k-block weight scales.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = (torch.randn(
        (n, math.ceil(k / block_size[1])),
        device=device, dtype=torch.float32))
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{},weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale
def _int8_gemm_helper_test(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None
):
    # Build test inputs and time the per-token-group quantization kernel.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = (torch.randn((n, k), device=device) * 5).t().to(dtype=out_dtype)
    print("input.dtype:", input.dtype)
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(
        input_2d, block_size[1])

    start_time_ = time.time()  # start timing
    for it in range(1000):
        q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    torch.cuda.synchronize()
    end_time_ = time.time()  # stop timing
    # (end - start) * 1000 ms over 1000 calls == average us per call
    elapsed_time = round((end_time_ - start_time_) * 1000, 7)
    print("_time:{} us\n".format(elapsed_time))
    return q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M
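# Note: the loop above times 1000 back-to-back launches with a single final
# synchronize, so the printed value is the average microseconds per call. An
# alternative sketch using Triton's built-in benchmark helper (returns ms):
#
#     ms = triton.testing.do_bench(
#         lambda: per_token_group_quant_int8(input_2d, block_size[1]),
#         return_mode="mean")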
def main():
    m_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    n_list = [576, 2048, 7168, 256, 7168, 1536, 1536]
    k_list = [7168, 512, 1024, 7168, 128, 7168, 1536]
    block_size = [128, 128]
    out_dtype = torch.bfloat16
    _n = []
    _k = []
    _m = []
    config_blocks = []
    config_num_warps = []
    config_num_stages = []
    config_M = []
    cost_times = []
    for i in range(0, len(k_list), 1):
        for m in m_list:
            print("m:{} n:{} k:{} ".format(m, n_list[i], k_list[i]))
            q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M = _int8_gemm_helper_test(
                m=m, n=n_list[i], k=k_list[i], block_size=block_size,
                out_dtype=torch.bfloat16)
            cost_times.append(elapsed_time)
            _n.append(n_list[i])
            _k.append(k_list[i])
            _m.append(m)
            config_blocks.append(BLOCK)
            config_num_warps.append(num_warps)
            config_num_stages.append(num_stages)
            config_M.append(M)
    # Build a DataFrame from these lists
    df = pd.DataFrame({
        'm': _m,
        'n': _n,
        'k': _k,
        'quant op time (us)': cost_times,
        'BLOCK': config_blocks,
        'num_warps': config_num_warps,
        'config_num_stages': config_num_stages,
        'config_M': config_M
    })
    # Write the DataFrame to an Excel file
    df.to_excel('output.xlsx', index=False)
    print("Results saved to output.xlsx.")


if __name__ == "__main__":
    main()