OpenDAS / TransformerEngine · Commits

Commit f6937668, authored Jun 04, 2025 by yuguo

[DCU] support block fp8 simu with int8 for Dense

parent 521f8d3b
Showing 11 changed files with 2335 additions and 9 deletions (+2335, -9)
qa/L0_pytorch_unittest/test.sh                                      +2    -0
tests/pytorch/test_float8_current_scaling_exact.py                  +2    -2
tests/pytorch/test_int8_blockwise_gemm_exact.py                     +708  -0
tests/pytorch/test_int8_blockwise_layers.py                         +175  -0
transformer_engine/pytorch/constants.py                             +1    -0
transformer_engine/pytorch/cpp_extensions/gemm.py                   +64   -1
transformer_engine/pytorch/fp8.py                                   +9    -4
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py        +3    -2
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +563  -0
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +568  -0
transformer_engine/pytorch/triton/per_token_group_quant.py          +240  -0
qa/L0_pytorch_unittest/test.sh

@@ -36,6 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
+python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
+NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_int8_blockwise_layers.xml $TE_PATH/tests/pytorch/test_int8_blockwise_layers || test_fail "test_int8_blockwise_layers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
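For reference, the simulated path is driven entirely by the NVTE_INT8_SIM_FP8 environment variable that the new test line above exports. It must be set before transformer_engine is imported, since the modules changed further down read it at import time. A minimal sketch of exercising the simulation from user code; shapes and the choice of module are illustrative and not taken from the tests:

# Minimal sketch: run a te.Linear forward/backward under the Float8BlockScaling
# recipe with int8 simulation enabled. NVTE_INT8_SIM_FP8 must be set before
# transformer_engine is imported.
import os
os.environ["NVTE_INT8_SIM_FP8"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

linear = te.Linear(256, 128, bias=False, params_dtype=torch.bfloat16).cuda()
x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8BlockScaling()):
    y = linear(x)      # blockwise "fp8" GEMM, simulated with the new int8 Triton kernels
y.sum().backward()     # dgrad/wgrad also dispatch to the int8 simulation path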
tests/pytorch/test_float8_current_scaling_exact.py

@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(
                     x, w, bias, gradient
                 )

@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         )
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() is not None
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
tests/pytorch/test_int8_blockwise_gemm_exact.py (new file, 0 → 100644)

This diff is collapsed; contents not shown here.
tests/pytorch/test_int8_blockwise_layers.py (new file, 0 → 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Tuple
import math
import os
import pathlib

import pytest
import torch

import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from references.blockwise_quantizer_reference import (
    BlockwiseQuantizerReference,
    QuantizeResult,
)
from test_float8_current_scaling_exact import (
    TestFP8RecipeLinearBase,
    TestFP8RecipeLayerNormLinearBase,
)
import logging

# Read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override
# the default tensor dump directory.
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
    TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)

recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()


class GetRecipes:

    @staticmethod
    def none():
        return None

    @staticmethod
    def fp8_blockwise():
        # return default configs
        return Float8BlockScaling()


# FP8 block scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=False,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, bgrad.
        # If all four tensors cannot be read, leave the tensor dump set to None.
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        assert recipe1 == GetRecipes.none, "Only None recipe is supported for recipe1"

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            dgrad_error=1,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_layernorm_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=False,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, bgrad.
        # If all four tensors cannot be read, leave the tensor dump set to None.
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR,
            recipe2,
            (batch_size, hidden_size, out_size),
            dtype,
            use_bias,
            "LayerNorm",
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.9,
            ln_out_error=0.5,
            dgrad_error=1.5,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )
transformer_engine/pytorch/constants.py

@@ -35,6 +35,7 @@ TE_DType_To_Torch = {
     tex.DType.kByte: torch.uint8,
     tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
     tex.DType.kFloat8E5M2: torch.float8_e5m2,
+    tex.DType.kInt8: torch.int8,
     tex.DType.kInt32: torch.int32,
     tex.DType.kFloat32: torch.float32,
     tex.DType.kFloat16: torch.half,
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -10,11 +10,13 @@ import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))

__all__ = [
    "general_gemm",
    "general_grouped_gemm",

@@ -60,6 +62,67 @@ def general_gemm(
    #     + "a valid `ub` communicator object."
    # )
    if int8_simulation_fp8 and (
        isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)
    ):
        assert not gelu, "GELU not supported with int8 simulation"
        assert gelu_in is None, "GELU input not supported with int8 simulation"
        assert bias is None, "Bias not supported with int8 simulation"
        assert not accumulate, "Accumulation not supported with int8 simulation"
        assert ub is None, "User buffer not supported with int8 simulation"
        assert ub_type is None, "User buffer type not supported with int8 simulation"
        assert extra_output is None, "Extra output not supported with int8 simulation"
        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
        if layout == "TN":
            qx_data = B._rowwise_data.view(dtype=torch.int8)
            qw_data = A._rowwise_data.view(dtype=torch.int8)
            ref_scales_x = B._rowwise_scale_inv
            ref_scales_w = A._rowwise_scale_inv
            y, _ = w8a8_block_int8_matmul(
                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128], output_dtype=out_dtype
            )
            return y, None, None, None
        elif layout == "NN":
            qdout_data = B._rowwise_data.view(dtype=torch.int8)
            qw_data = A._columnwise_data.view(dtype=torch.int8)
            ref_scales_dout = B._rowwise_scale_inv
            ref_scales_w = A._columnwise_scale_inv
            y, _ = w8a8_block_int8_matmul(
                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128], output_dtype=out_dtype
            )
            return y, None, None, None
        elif layout == "NT":
            qdout_data = B._columnwise_data.view(dtype=torch.int8)
            qx_data = A._columnwise_data.view(dtype=torch.int8)
            ref_scales_dout = B._columnwise_scale_inv
            ref_scales_x = A._columnwise_scale_inv
            y, _ = w8a8_block_int8_matmul_wgrad(
                qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128], output_dtype=out_dtype
            )
            return y, None, None, None
        else:
            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")

    if ub is not None:
        assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
        if ub_type == tex.CommOverlapType.RS:
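The branches above only reinterpret the quantized payloads of the Float8BlockwiseQTensorBase operands as int8 via .view(dtype=torch.int8) and hand them, together with the stored scale inverses, to the new Triton kernels (their diffs are collapsed below). As a hedged reference for what w8a8_block_int8_matmul is expected to compute, assuming per-token 1x128-group scales for the activation and 128x128-block scales for the weight, a plain PyTorch sketch; this helper is illustrative and not part of the commit:

# Rough PyTorch reference for the assumed kernel semantics. q_x: [M, K] int8 with
# dequantization scales x_s of shape [M, K // block]; q_w: [N, K] int8 with scales
# w_s of shape [ceil(N / block), K // block]. Result approximates x @ w.T.
import torch

def blockwise_int8_matmul_ref(q_x, x_s, q_w, w_s, block=128, out_dtype=torch.bfloat16):
    M, K = q_x.shape
    N, _ = q_w.shape
    y = torch.zeros(M, N, dtype=torch.float32, device=q_x.device)
    for kb in range(K // block):
        ks = slice(kb * block, (kb + 1) * block)
        partial = q_x[:, ks].float() @ q_w[:, ks].float().t()   # int8 block product in fp32
        scale_x = x_s[:, kb].unsqueeze(1)                       # [M, 1], one scale per row
        scale_w = w_s[:, kb].repeat_interleave(block)[:N]       # [N], one scale per 128-row block
        y += partial * scale_x * scale_w.unsqueeze(0)           # de-scale and accumulate
    return y.to(out_dtype)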
transformer_engine/pytorch/fp8.py

@@ -27,16 +27,18 @@ from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))

__all__ = ["fp8_autocast", "fp8_model_init"]

if IS_HIP_EXTENSION:
    from transformer_engine.pytorch.utils import is_K100_AI, is_BW


def check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    if IS_HIP_EXTENSION:
        if get_device_compute_capability() == (9, 4):
            return True, ""
        if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
            return True, "DCU turn on fp8 simulation with int8"
        else:
            return False, "DCU not support fp8 for now"
    else:

@@ -61,7 +63,10 @@ def check_mxfp8_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
    """Return if fp8 block scaling support is available"""
    if IS_HIP_EXTENSION:
        if is_K100_AI() or is_BW():
            return True, ""
        else:
            return False, "DCU not support block_scaling fp8 for now"
    if (
        get_device_compute_capability() >= (9, 0)
        and get_device_compute_capability() < (10, 0)
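The practical effect on DCU (HIP) builds is that the availability checks can report FP8 support once the env var is set. A quick illustration of the expected return values, based purely on the branches added in this hunk; on CUDA builds the original compute-capability checks still apply:

# Illustrative only: return values follow the HIP/DCU branches shown above.
from transformer_engine.pytorch.fp8 import check_fp8_support, check_fp8_block_scaling_support

print(check_fp8_support())
# K100_AI/BW DCU with NVTE_INT8_SIM_FP8=1 -> (True, "DCU turn on fp8 simulation with int8")
# K100_AI/BW DCU without the env var     -> (False, "DCU not support fp8 for now")

print(check_fp8_block_scaling_support())
# K100_AI/BW DCU -> (True, "")   block scaling is reported available regardless of the env var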
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

@@ -9,7 +9,7 @@ from typing import Optional, Tuple, Iterable
import math
import torch
import transformer_engine_torch as tex
import os
from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc

@@ -17,6 +17,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten

int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))


class Float8BlockQuantizer(Quantizer):
    """Builder class for tensors quantized with current scaling using

@@ -44,7 +45,7 @@ class Float8BlockQuantizer(Quantizer):
        block_scaling_dim: int = 2,
    ) -> None:
        super().__init__(rowwise=rowwise, columnwise=columnwise)
-       self.dtype = fp8_dtype
+       self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
        self.block_len = 128
        self.force_pow_2_scales = force_pow_2_scales
        self.amax_epsilon = amax_epsilon
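With the environment variable set, the quantizer therefore emits int8 payloads while keeping the 128-element block scaling layout. A small illustration of the intended effect; the keyword arguments are an assumption based on the parameters visible in this hunk and may not match the real constructor exactly:

# Assumes NVTE_INT8_SIM_FP8=1 was exported before importing transformer_engine.
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

quantizer = Float8BlockQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True)
assert quantizer.dtype == tex.DType.kInt8   # int8 payload replaces the requested FP8 encoding
assert quantizer.block_len == 128           # 128-element block scaling layout is unchanged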
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (new file, 0 → 100644)

This diff is collapsed; contents not shown here.

transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (new file, 0 → 100644)

This diff is collapsed; contents not shown here.
transformer_engine/pytorch/triton/per_token_group_quant.py (new file, 0 → 100644)

import torch
import time
from typing import Optional, Type, Any, Dict, List, Tuple
import pandas as pd
import os
import json
import triton
import triton.language as tl
import pandas as pd
import logging
import math


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


@triton.jit
def _per_token_group_quant_int8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Avoid dividing by zero
    eps,
    # Information for int8
    int8_min,
    int8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.
    This function converts the tensor values into int8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / int8_max
    y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def per_token_group_quant_int8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
    It converts the tensor values into signed int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.int8`
            is supported for now.
    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_max = iinfo.max
    int8_min = iinfo.min

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)  # N is block_size[1]
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_int8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        N,
        eps,
        int8_min=int8_min,
        int8_max=int8_max,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s, BLOCK, num_warps, num_stages, M


def _int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None,
):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output-channel weight quantization.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = torch.randn(
        (math.ceil(n / block_size[0]), math.ceil(k / block_size[1])),
        device=device,
        dtype=torch.float32,
    )
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{},weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale


def _int8_gemm_helper_b(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None,
):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output-channel weight quantization.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = to_int8(torch.randn((n, k), device=device) * 5)
    weight_scale = torch.randn(
        (n, math.ceil(k / block_size[1])),
        device=device,
        dtype=torch.float32,
    )
    print("input.dtype:", input.dtype)
    # print("m:{} n:{} k:{},weight_scale.shape:{}".format(m, n, k, weight_scale.shape))
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    return q_input, x_scale, weight, weight_scale


def _int8_gemm_helper_test(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    best_config: Optional[list] = None,
):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output-channel weight quantization.
    input = (torch.randn((m, k), device=device) * 5).to(dtype=out_dtype)
    weight = (torch.randn((n, k), device=device) * 5).t().to(dtype=out_dtype)
    print("input.dtype:", input.dtype)
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(
        input_2d, block_size[1]
    )
    start_time_ = time.time()  # start timing
    for it in range(1000):
        q_input, x_scale, _, _, _, _ = per_token_group_quant_int8(input_2d, block_size[1])
    torch.cuda.synchronize()
    end_time_ = time.time()  # stop timing
    elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
    print("_time:{} us\n".format(elapsed_time))
    return q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M


def main():
    m_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
    n_list = [576, 2048, 7168, 256, 7168, 1536, 1536]
    k_list = [7168, 512, 1024, 7168, 128, 7168, 1536]
    block_size = [128, 128]
    out_dtype = torch.bfloat16
    _n = []
    _k = []
    _m = []
    config_blocks = []
    config_num_warps = []
    config_num_stages = []
    config_M = []
    cost_times = []
    for i in range(0, len(k_list), 1):
        for m in m_list:
            print("m:{} n:{} k:{} ".format(m, n_list[i], k_list[i]))
            q_input, x_scale, elapsed_time, BLOCK, num_warps, num_stages, M = _int8_gemm_helper_test(
                m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=torch.bfloat16
            )
            cost_times.append(elapsed_time)
            _n.append(n_list[i])
            _k.append(k_list[i])
            _m.append(m)
            config_blocks.append(BLOCK)
            config_num_warps.append(num_warps)
            config_num_stages.append(num_stages)
            config_M.append(M)

    # Build a DataFrame from these lists
    df = pd.DataFrame({
        'm': _m,
        'n': _n,
        'k': _k,
        'quantization op time': cost_times,
        'BLOCK': config_blocks,
        'num_warps': config_num_warps,
        'config_num_stages': config_num_stages,
        'config_M': config_M,
    })
    # Write the DataFrame to an Excel file
    df.to_excel('output.xlsx', index=False)
    print("Table saved to output.xlsx.")


if __name__ == "__main__":
    main()
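A short usage sketch for the quantization entry point defined above (shapes are illustrative): quantize a bf16 activation into int8 with one fp32 scale per 128-column group, then multiply back to check the round trip.

import torch
from transformer_engine.pytorch.triton.per_token_group_quant import per_token_group_quant_int8

x = torch.randn(16, 256, dtype=torch.bfloat16, device="cuda")
x_q, x_s, BLOCK, num_warps, num_stages, M = per_token_group_quant_int8(x, group_size=128)

# x_q: [16, 256] int8; x_s: [16, 2] fp32, one scale (absmax / 127) per 128-column group
x_deq = x_q.float().view(16, 2, 128) * x_s.unsqueeze(-1)
print((x_deq.view(16, 256) - x.float()).abs().max())   # error bounded by one quantization step per group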