OpenDAS / TransformerEngine / Commits

Commit fdf60506
Authored Jul 11, 2025 by wenjh
Merge branch 'develop_v2.4'
Parents: 403db136, 3b1f30a9
Showing 4 changed files with 211 additions and 152 deletions:
tests/pytorch/test_int8_blockwise_gemm_exact.py                     +35  -16
transformer_engine/pytorch/cpp_extensions/gemm.py                   +31  -15
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py         +67  -57
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py   +78  -64
tests/pytorch/test_int8_blockwise_gemm_exact.py
@@ -2,7 +2,8 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
+import w8a8_matmul_extension
+from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
@@ -195,11 +196,17 @@ def cublas_gemm_fp8_blockwise_case_fw(
     )
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    y, _ = w8a8_block_int8_matmul(
-        qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
-        output_dtype=out_dtype
-    )
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
+        y, _ = w8a8_block_int8_matmul(
+            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
+            output_dtype=out_dtype
+        )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
+            output_dtype=out_dtype
+        )
     # print("int8 gemm output: ", y)
     # print("int8 gemm output shape: ", y.shape)
@@ -374,10 +381,16 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    y, _ = w8a8_block_int8_matmul(
-        qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
-        output_dtype=dx_dtype
-    )
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
+        y, _ = w8a8_block_int8_matmul(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
+            output_dtype=dx_dtype
+        )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
+            output_dtype=dx_dtype
+        )
     # print("int8 gemm dx: ", y)
@@ -553,12 +566,18 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    y, _ = w8a8_block_int8_matmul_wgrad(
-        qdout_data, qx_data, ref_scales_dout, ref_scales_x,
-        dw.clone() if accumulate else None,
-        accumulate, [block_len, block_len],
-        output_dtype=dw_dtype
-    )
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
+        y, _ = w8a8_block_int8_matmul_wgrad(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x,
+            dw.clone() if accumulate else None,
+            accumulate, [block_len, block_len],
+            output_dtype=dw_dtype
+        )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x,
+            dw.clone() if accumulate else None,
+            accumulate, [block_len, block_len],
+            output_dtype=dw_dtype
+        )
     # print("int8 gemm dw: ", y)
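
Background note: the kernels under test take int8 data plus per-block dequantization scales (one per [block_len, block_len] weight block, one per row-of-128 activation block). A rough, torch-only sketch of a dequantize-then-matmul reference for that layout; the function name and exact scale layout here are illustrative assumptions, not the repo's native reference implementation:

import torch

def blockwise_int8_dequant_matmul(a_q, b_q, a_scale, b_scale, block=128, out_dtype=torch.bfloat16):
    # a_q: [M, K] int8 activations, b_q: [N, K] int8 weights (NT layout as above).
    # a_scale: [M, K // block], b_scale: [N // block, K // block] dequant scales.
    # Assumes M, N, K are multiples of `block`.
    a = a_q.float() * a_scale.float().repeat_interleave(block, dim=1)
    b = b_q.float() * b_scale.float().repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (a @ b.t()).to(out_dtype)
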
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -8,6 +8,7 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
+import w8a8_matmul_extension
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -75,11 +76,16 @@ def general_gemm(
         )
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        y, _ = w8a8_block_int8_matmul(
-            qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-            output_dtype=out_dtype
-        )
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
+            y, _ = w8a8_block_int8_matmul(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
+        else:
+            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return y, None, None, None
     elif layout == "NN":
@@ -91,11 +97,16 @@ def general_gemm(
         )
         ref_scales_dout = B._rowwise_scale_inv
         ref_scales_w = A._columnwise_scale_inv
-        y, _ = w8a8_block_int8_matmul(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-            output_dtype=out_dtype
-        )
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
+            y, _ = w8a8_block_int8_matmul(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
+        else:
+            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return y, None, None, None
     elif layout == "NT":
@@ -107,11 +118,16 @@ def general_gemm(
         )
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        out, _ = w8a8_block_int8_matmul_wgrad(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-            output_dtype=out_dtype
-        )
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
+            out, _ = w8a8_block_int8_matmul_wgrad(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
+        else:
+            out = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return out, None, None, None
     else:
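
Note: to see which branch a given machine will take without importing Transformer Engine, the gate can be reproduced with plain PyTorch. This assumes get_device_compute_capability() reports the same (major, minor) pair as torch.cuda.get_device_capability(), which is how it is compared against (9, 3) above:

import torch

major, minor = torch.cuda.get_device_capability()
block_len = 128  # the only block length the CUDA extension path handles
uses_cuda_extension = (major, minor) >= (9, 3) and block_len == 128
print(f"sm_{major}{minor}: "
      + ("w8a8_matmul_extension kernels" if uses_cuda_extension else "Triton w8a8_block_int8_matmul kernels"))
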
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
@@ -11,6 +11,8 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
+import w8a8_matmul_extension
+from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
@@ -574,7 +576,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     output = torch.stack(output).contiguous()
-    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size, out_dtype)
     torch_output = torch_output.view(-1, torch_output.size(-1))
     # print(f"zhenggf, torch_output:{torch_output.shape}")
@@ -605,16 +607,20 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     q_input, x_scale, weight, weight_scale = _int8_gemm_helper(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
     print(f"zhenggf, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
+    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    output, config = w8a8_block_int8_matmul(
-        q_input, weight, x_scale, weight_scale, block_size,
-        output_dtype=out_dtype,
-        best_config=best_config
-    )
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        output, config = w8a8_block_int8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size,
+            output_dtype=out_dtype,
+            best_config=best_config
+        )
+    else:
+        output = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size,
+            output_dtype=out_dtype
+        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check failed!!!")
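
Note: the accuracy gate above is a bare allclose plus a print. As a standalone reproduction of the same tolerance (a suggestion only, not part of this merge), torch.testing.assert_close reports which elements differ when the check fails:

import torch

def check_int8_gemm_output(output: torch.Tensor, reference: torch.Tensor) -> None:
    # Same tolerances as the helper above: rtol=1e-2, atol=5e-2.
    torch.testing.assert_close(output.float(), reference.float(), rtol=1e-2, atol=5e-2)
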
@@ -622,28 +628,29 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     else:
         print("triton accuracy check passed")
     # unit test end
-    g = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(g):
-        for it in range(1000):
-            output, _ = w8a8_block_int8_matmul(
-                q_input, weight, x_scale, weight_scale, block_size,
-                output_dtype=out_dtype,
-                best_config=best_config
-            )
-    torch.cuda.synchronize()
-    start_time_ = time.time()  # start timing
-    g.replay()
-    torch.cuda.synchronize()
-    end_time_ = time.time()  # stop timing
-    elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
-    print("_time:{} us\n".format(elapsed_time))
-    quantiles = [0.5, 0.2, 0.8]
-    gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for it in range(1000):
+                output, _ = w8a8_block_int8_matmul(
+                    q_input, weight, x_scale, weight_scale, block_size,
+                    output_dtype=out_dtype,
+                    best_config=best_config
+                )
+        torch.cuda.synchronize()
+        start_time_ = time.time()  # start timing
+        g.replay()
+        torch.cuda.synchronize()
+        end_time_ = time.time()  # stop timing
+        elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
+        print("_time:{} us\n".format(elapsed_time))
+        quantiles = [0.5, 0.2, 0.8]
+        gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
     if bias is not None:
         output = output + bias
     return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config

 def get_triton_cache(file_path, n, k, block_n, block_k):
     # returns the contents of the json file as a dict
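
Note: the timing block wrapped above captures 1000 kernel launches in a CUDA graph and times a single replay. A self-contained, torch-only sketch of the same pattern, with a plain fp16 matmul standing in for the INT8 kernel and the side-stream warmup recommended by the PyTorch CUDA-graph docs:

import time
import torch

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
c = torch.empty(2048, 2048, device="cuda", dtype=torch.float16)

# Warm up on a side stream so lazy cuBLAS initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        torch.matmul(a, b, out=c)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    for _ in range(1000):
        torch.matmul(a, b, out=c)

torch.cuda.synchronize()
start = time.time()
g.replay()                      # one replay launches all 1000 captured GEMMs
torch.cuda.synchronize()
print(f"replay took {(time.time() - start) * 1000:.3f} ms")
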
@@ -800,33 +807,36 @@ def main():
         best_config = []
         apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
             gpu_costtimes.append(gpu_costtime)
             _n.append(n_list[i])
             _k.append(k_list[i])
             _m.append(m)
             print(f"zhenggf, {config}")
             print(f"zhenggf, {config.kwargs}")
             _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
             _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
             _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
             _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
+        else:
+            apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         # build a DataFrame from these lists
         df = pd.DataFrame({
             'm': _m, 'n': _n, 'k': _k,
             'linear-layer quantized GEMM op time': cost_times,
             'GPU op time': gpu_costtimes,
             'BLOCK_SIZE_M': _configs_block_m,
             'BLOCK_SIZE_N': _configs_block_n,
             'BLOCK_SIZE_K': _configs_block_k,
             'GROUP_SIZE_M': _configs_block_group_m,
             'num_warps': _configs_block_num_warps,
             'num_stages': _configs_block_num_stages,
             # 'kpack': _configs_kpack
         })
         # write the DataFrame to an Excel file
         df.to_excel('gemmoutput.xlsx', index=False)
         print("Table saved to gemmoutput.xlsx.")

 if __name__ == "__main__":
     main()
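
Note: main() dumps the sweep results to gemmoutput.xlsx via pandas. A minimal standalone sketch of that output step, with zero placeholders standing in for the measured timings (requires pandas plus an Excel writer such as openpyxl):

import pandas as pd

df = pd.DataFrame({
    "m": [4096],
    "n": [7168],
    "k": [2048],
    "gemm op time (ms)": [0.0],   # placeholder, not a measurement
    "GPU op time (ms)": [0.0],    # placeholder, not a measurement
    "BLOCK_SIZE_M": [128],
    "BLOCK_SIZE_N": [128],
    "BLOCK_SIZE_K": [64],
    "GROUP_SIZE_M": [8],
    "num_warps": [8],
    "num_stages": [3],
})
df.to_excel("gemmoutput.xlsx", index=False)
print("table written to gemmoutput.xlsx")
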
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
@@ -11,7 +11,8 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
+import w8a8_matmul_extension
+from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
 device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
@@ -463,11 +464,17 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
 ):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        C_list[i], config = w8a8_block_int8_matmul_wgrad(
-            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
-            output_dtype=output_dtype,
-            best_config=best_config
-        )
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+            C_list[i], config = w8a8_block_int8_matmul_wgrad(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
+                output_dtype=output_dtype,
+                best_config=best_config
+            )
+        else:
+            C_list[i] = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
+                output_dtype=output_dtype
+            )
     return C_list

 def w8a8_block_int8_matmul_wgrad_batched(
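
Note: the wgrad kernels take the existing gradient buffer C plus an `accumulate` flag (hence `dw.clone() if accumulate else None` in the test file above). A torch-only sketch of the semantics being exercised, purely illustrative:

import torch

def wgrad_reference(dout, x, dw=None, accumulate=False):
    # dout: [M, N] output gradient, x: [M, K] input -> weight gradient [N, K].
    grad = dout.float().t() @ x.float()
    if accumulate and dw is not None:
        grad = dw.float() + grad      # accumulate into the existing buffer
    return grad
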
@@ -613,7 +620,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
-    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size, out_dtype)
     # print(f"zhenggf, torch_output:{torch_output.shape}")
     x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
@@ -648,7 +655,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
     print(f"zhenggf, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
+    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
     weight_scale = weight_scale.permute(1, 0).contiguous()
@@ -657,42 +664,47 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf, after transposing, passed to the triton kernel, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-    output, config = w8a8_block_int8_matmul_wgrad(
-        q_input, weight, x_scale, weight_scale, output, False, block_size,
-        output_dtype=out_dtype,
-        best_config=best_config
-    )
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        output, config = w8a8_block_int8_matmul_wgrad(
+            q_input, weight, x_scale, weight_scale, output, False, block_size,
+            output_dtype=out_dtype,
+            best_config=best_config
+        )
+    else:
+        output = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+            q_input, weight, x_scale, weight_scale, output, False, block_size,
+            output_dtype=out_dtype
+        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check failed!!!")
     else:
         print("triton accuracy check passed")
     # unit test end
-    g = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(g):
-        for it in range(1000):
-            output, _ = w8a8_block_int8_matmul_wgrad(
-                q_input, weight, x_scale, weight_scale, output, False, block_size,
-                output_dtype=out_dtype,
-                best_config=best_config
-            )
-    torch.cuda.synchronize()
-    start_time_ = time.time()  # start timing
-    g.replay()
-    torch.cuda.synchronize()
-    end_time_ = time.time()  # stop timing
-    elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
-    print("_time:{} us\n".format(elapsed_time))
-    quantiles = [0.5, 0.2, 0.8]
-    gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for it in range(1000):
+                output, _ = w8a8_block_int8_matmul_wgrad(
+                    q_input, weight, x_scale, weight_scale, output, False, block_size,
+                    output_dtype=out_dtype,
+                    best_config=best_config
+                )
+        torch.cuda.synchronize()
+        start_time_ = time.time()  # start timing
+        g.replay()
+        torch.cuda.synchronize()
+        end_time_ = time.time()  # stop timing
+        elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
+        print("_time:{} us\n".format(elapsed_time))
+        quantiles = [0.5, 0.2, 0.8]
+        gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
     if bias is not None:
         output = output + bias
     return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config

 def get_triton_cache(file_path, n, k, block_n, block_k):
     # returns the contents of the json file as a dict
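
Note: both benchmark helpers also time the kernel with triton.testing.do_bench, which runs the callable repeatedly and reports a time in milliseconds. A standalone sketch of that call, again with a plain matmul standing in for the wgrad kernel:

import torch
import triton.testing

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Mean runtime in milliseconds over do_bench's internal warmup/repeat schedule.
ms = triton.testing.do_bench(lambda: a @ b, quantiles=None, return_mode="mean")
print(f"mean GEMM time: {ms:.4f} ms")
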
@@ -850,34 +862,36 @@ def main():
         best_config = []
         apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
             gpu_costtimes.append(gpu_costtime)
             _n.append(n_list[i])
             _k.append(k_list[i])
             _m.append(m)
             print(f"zhenggf, {config}")
             print(f"zhenggf, {config.kwargs}")
             _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
             _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
             _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
             _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
+        else:
+            apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         # build a DataFrame from these lists
         df = pd.DataFrame({
             'm': _m, 'n': _n, 'k': _k,
             'linear-layer quantized GEMM op time': cost_times,
             'GPU op time': gpu_costtimes,
             'BLOCK_SIZE_M': _configs_block_m,
             'BLOCK_SIZE_N': _configs_block_n,
             'BLOCK_SIZE_K': _configs_block_k,
             'GROUP_SIZE_M': _configs_block_group_m,
             'num_warps': _configs_block_num_warps,
             'num_stages': _configs_block_num_stages,
             # 'kpack': _configs_kpack
         })
         # write the DataFrame to an Excel file
         df.to_excel('gemmoutput.xlsx', index=False)
         print("Table saved to gemmoutput.xlsx.")

 if __name__ == "__main__":
     main()