Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
9ab6cd98
Commit
9ab6cd98
authored
Jul 29, 2025
by
yuguo
Browse files
Merge branch 'main' of
http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
parents
782f6092
84e8ce2f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
73 deletions
+83
-73
tests/pytorch/test_int8_blockwise_gemm_exact.py
tests/pytorch/test_int8_blockwise_gemm_exact.py
+18
-19
transformer_engine/pytorch/cpp_extensions/gemm.py
transformer_engine/pytorch/cpp_extensions/gemm.py
+17
-18
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
+13
-16
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
...mer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
+16
-19
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+19
-1
No files found.
tests/pytorch/test_int8_blockwise_gemm_exact.py
View file @
9ab6cd98
...
@@ -5,10 +5,9 @@ import transformer_engine_torch as tex
...
@@ -5,10 +5,9 @@ import transformer_engine_torch as tex
import
warnings
import
warnings
try
:
try
:
import
lightop
import
lightop
enable_lightop
=
True
except
ImportError
:
except
ImportError
:
enable_lightop
=
False
pass
from
transformer_engine.pytorch
import
get_device_compute_capability
from
transformer_engine.pytorch
.utils
import
use_lightop_w8a8
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.constants
import
TE_DType
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
)
from
transformer_engine.pytorch.fp8
import
(
FP8GlobalStateManager
,
blockwise_fp8_block_len
)
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
from
transformer_engine.pytorch.tensor.float8_blockwise_tensor
import
(
...
@@ -202,15 +201,15 @@ def cublas_gemm_fp8_blockwise_case_fw(
...
@@ -202,15 +201,15 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_x
=
qx
.
_columnwise_scale_inv
if
x_columnwise
else
qx
.
_rowwise_scale_inv
ref_scales_x
=
qx
.
_columnwise_scale_inv
if
x_columnwise
else
qx
.
_rowwise_scale_inv
ref_scales_w
=
qw
.
_columnwise_scale_inv
if
w_columnwise
else
qw
.
_rowwise_scale_inv
ref_scales_w
=
qw
.
_columnwise_scale_inv
if
w_columnwise
else
qw
.
_rowwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
block_len
,
block_len
]):
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
out_dtype
,
'TN'
)
else
:
y
,
_
=
w8a8_block_int8_matmul
(
y
,
_
=
w8a8_block_int8_matmul
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
output_dtype
=
out_dtype
output_dtype
=
out_dtype
)
)
else
:
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
block_len
,
block_len
],
out_dtype
)
# print("int8 gemm output: ", y)
# print("int8 gemm output: ", y)
# print("int8 gemm output shape: ", y.shape)
# print("int8 gemm output shape: ", y.shape)
...
@@ -385,15 +384,15 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
...
@@ -385,15 +384,15 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_dout
=
qdout
.
_columnwise_scale_inv
if
dout_columnwise
else
qdout
.
_rowwise_scale_inv
ref_scales_dout
=
qdout
.
_columnwise_scale_inv
if
dout_columnwise
else
qdout
.
_rowwise_scale_inv
ref_scales_w
=
qw
.
_columnwise_scale_inv
if
w_columnwise
else
qw
.
_rowwise_scale_inv
ref_scales_w
=
qw
.
_columnwise_scale_inv
if
w_columnwise
else
qw
.
_rowwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
block_len
,
block_len
]):
y
=
lightop
.
gemm_w8a8_asm
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
block_len
,
block_len
],
dx_dtype
,
'TN'
)
else
:
y
,
_
=
w8a8_block_int8_matmul
(
y
,
_
=
w8a8_block_int8_matmul
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
block_len
,
block_len
],
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
block_len
,
block_len
],
output_dtype
=
dx_dtype
output_dtype
=
dx_dtype
)
)
else
:
y
=
lightop
.
gemm_w8a8_asm
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
block_len
,
block_len
],
dx_dtype
)
# print("int8 gemm dx: ", y)
# print("int8 gemm dx: ", y)
...
@@ -569,16 +568,16 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
...
@@ -569,16 +568,16 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
block_len
,
block_len
])
:
y
,
_
=
w8a8_block_int8_matmul
_wgrad
(
y
=
lightop
.
gemm_w8a8
_wgrad
_asm
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
dw
.
clone
()
if
accumulate
else
None
,
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
dw
.
clone
()
if
accumulate
else
None
,
accumulate
,
[
block_len
,
block_len
],
accumulate
,
[
block_len
,
block_len
],
dw_dtype
,
'TN'
output_dtype
=
dw_dtype
)
)
else
:
else
:
y
=
lightop
.
gemm_w8a8
_wgrad
_asm
(
y
,
_
=
w8a8_block_int8_matmul
_wgrad
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
dw
.
clone
()
if
accumulate
else
None
,
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
dw
.
clone
()
if
accumulate
else
None
,
accumulate
,
[
block_len
,
block_len
],
dw_dtype
accumulate
,
[
block_len
,
block_len
],
output_dtype
=
dw_dtype
)
)
# print("int8 gemm dw: ",y)
# print("int8 gemm dw: ",y)
...
...
transformer_engine/pytorch/cpp_extensions/gemm.py
View file @
9ab6cd98
...
@@ -8,12 +8,10 @@ from typing import Iterable, Optional, Tuple, Union, List
...
@@ -8,12 +8,10 @@ from typing import Iterable, Optional, Tuple, Union, List
import
os
import
os
import
torch
import
torch
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
import
warnings
try
:
try
:
import
lightop
import
lightop
enable_lightop
=
True
except
ImportError
:
except
ImportError
:
enable_lightop
=
False
pass
from
..constants
import
TE_DType
,
TE_DType_To_Torch
from
..constants
import
TE_DType
,
TE_DType_To_Torch
from
..utils
import
get_sm_count
,
_empty_tensor
from
..utils
import
get_sm_count
,
_empty_tensor
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt
import
w8a8_block_int8_matmul
,
w8a8_block_int8_matmul_batched
from
transformer_engine.pytorch.triton.blockwise_int8_gemm_nt
import
w8a8_block_int8_matmul
,
w8a8_block_int8_matmul_batched
...
@@ -23,6 +21,7 @@ from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTens
...
@@ -23,6 +21,7 @@ from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTens
from
..tensor._internal.float8_tensor_base
import
Float8TensorBase
from
..tensor._internal.float8_tensor_base
import
Float8TensorBase
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
from
...debug.pytorch.debug_quantization
import
DebugQuantizer
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
transformer_engine.pytorch.utils
import
use_lightop_w8a8
from
transformer_engine.pytorch.triton.per_token_group_quant
import
(
per_token_quant_fp8_to_int8
,
from
transformer_engine.pytorch.triton.per_token_group_quant
import
(
per_token_quant_fp8_to_int8
,
per_token_quant_fp8_to_int8_v2
,
per_token_quant_fp8_to_int8_v2
,
per_token_quant_fp8_to_int8_opt
,
per_token_quant_fp8_to_int8_opt
,
...
@@ -92,15 +91,15 @@ def general_gemm(
...
@@ -92,15 +91,15 @@ def general_gemm(
)
)
ref_scales_x
=
B
.
_rowwise_scale_inv
ref_scales_x
=
B
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_rowwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
blockwise_fp8_block_len
,
blockwise_fp8_block_len
]):
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
else
:
y
,
_
=
w8a8_block_int8_matmul
(
y
,
_
=
w8a8_block_int8_matmul
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
output_dtype
=
out_dtype
)
)
else
:
y
=
lightop
.
gemm_w8a8_asm
(
qx_data
,
qw_data
,
ref_scales_x
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
)
return
y
,
None
,
None
,
None
return
y
,
None
,
None
,
None
elif
layout
==
"NN"
:
elif
layout
==
"NN"
:
...
@@ -112,15 +111,15 @@ def general_gemm(
...
@@ -112,15 +111,15 @@ def general_gemm(
)
)
ref_scales_dout
=
B
.
_rowwise_scale_inv
ref_scales_dout
=
B
.
_rowwise_scale_inv
ref_scales_w
=
A
.
_columnwise_scale_inv
ref_scales_w
=
A
.
_columnwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
blockwise_fp8_block_len
,
blockwise_fp8_block_len
]):
y
=
lightop
.
gemm_w8a8_asm
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
else
:
y
,
_
=
w8a8_block_int8_matmul
(
y
,
_
=
w8a8_block_int8_matmul
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
output_dtype
=
out_dtype
)
)
else
:
y
=
lightop
.
gemm_w8a8_asm
(
qdout_data
,
qw_data
,
ref_scales_dout
,
ref_scales_w
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
)
return
y
,
None
,
None
,
None
return
y
,
None
,
None
,
None
elif
layout
==
"NT"
:
elif
layout
==
"NT"
:
...
@@ -132,15 +131,15 @@ def general_gemm(
...
@@ -132,15 +131,15 @@ def general_gemm(
)
)
ref_scales_dout
=
B
.
_columnwise_scale_inv
ref_scales_dout
=
B
.
_columnwise_scale_inv
ref_scales_x
=
A
.
_columnwise_scale_inv
ref_scales_x
=
A
.
_columnwise_scale_inv
if
get_device_compute_capability
()
<
(
9
,
3
)
or
blockwise_fp8_block_len
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
([
blockwise_fp8_block_len
,
blockwise_fp8_block_len
]):
out
=
lightop
.
gemm_w8a8_wgrad_asm
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
,
'TN'
)
else
:
out
,
_
=
w8a8_block_int8_matmul_wgrad
(
out
,
_
=
w8a8_block_int8_matmul_wgrad
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
output_dtype
=
out_dtype
output_dtype
=
out_dtype
)
)
else
:
out
=
lightop
.
gemm_w8a8_wgrad_asm
(
qdout_data
,
qx_data
,
ref_scales_dout
,
ref_scales_x
,
out
,
accumulate
,
[
blockwise_fp8_block_len
,
blockwise_fp8_block_len
],
out_dtype
)
return
out
,
None
,
None
,
None
return
out
,
None
,
None
,
None
else
:
else
:
...
...
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py
View file @
9ab6cd98
...
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
...
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
import
functools
import
functools
import
logging
import
logging
import
warnings
try
:
try
:
import
lightop
import
lightop
enable_lightop
=
True
except
ImportError
:
except
ImportError
:
enable_lightop
=
False
pass
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
from
transformer_engine.pytorch.utils
import
use_lightop_w8a8
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -615,24 +613,23 @@ def apply_w8a8_block_int8_linear_helper(m: int,
...
@@ -615,24 +613,23 @@ def apply_w8a8_block_int8_linear_helper(m: int,
torch_output
=
native_w8a8_block_int8_matmul
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
out_dtype
)
torch_output
=
native_w8a8_block_int8_matmul
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
out_dtype
)
x_scale
=
x_scale
.
permute
(
1
,
0
).
contiguous
()
x_scale
=
x_scale
.
permute
(
1
,
0
).
contiguous
()
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
(
block_size
):
output
=
lightop
.
gemm_w8a8_asm
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
out_dtype
,
'TN'
)
else
:
output
,
config
=
w8a8_block_int8_matmul
(
output
,
config
=
w8a8_block_int8_matmul
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
output_dtype
=
out_dtype
,
output_dtype
=
out_dtype
,
best_config
=
best_config
best_config
=
best_config
)
)
else
:
output
=
lightop
.
gemm_w8a8_asm
(
q_input
,
weight
,
x_scale
,
weight_scale
,
block_size
,
out_dtype
)
if
not
torch
.
allclose
(
output
,
torch_output
,
rtol
=
1e-2
,
atol
=
5e-2
):
if
not
torch
.
allclose
(
output
,
torch_output
,
rtol
=
1e-2
,
atol
=
5e-2
):
print
(
"triton 精度检查不合格!!!"
)
print
(
"w8a8_block_int8 精度检查不合格!!!"
)
else
:
else
:
print
(
"triton 精度检查合格"
)
print
(
"w8a8_block_int8 精度检查合格"
)
# unit test end
# unit test end
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
:
if
not
use_lightop_w8a8
(
block_size
)
:
g
=
torch
.
cuda
.
CUDAGraph
()
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
with
torch
.
cuda
.
graph
(
g
):
for
it
in
range
(
1000
):
for
it
in
range
(
1000
):
...
@@ -811,7 +808,7 @@ def main():
...
@@ -811,7 +808,7 @@ def main():
best_config
=
[]
best_config
=
[]
apply_w8a8_block_int8_linear_batched_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
apply_w8a8_block_int8_linear_batched_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
:
if
not
use_lightop_w8a8
(
block_size
)
:
output
,
elapsed_time
,
gpu_costtime
,
config
=
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
output
,
elapsed_time
,
gpu_costtime
,
config
=
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
cost_times
.
append
(
elapsed_time
)
cost_times
.
append
(
elapsed_time
)
...
@@ -831,7 +828,7 @@ def main():
...
@@ -831,7 +828,7 @@ def main():
else
:
else
:
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
:
if
not
use_lightop_w8a8
(
block_size
)
:
# 创建一个包含这三个列表的 DataFrame
# 创建一个包含这三个列表的 DataFrame
df
=
pd
.
DataFrame
({
'm'
:
_m
,
'n'
:
_n
,
'k'
:
_k
,
'线性层gemm量化算子耗时'
:
cost_times
,
'GPU算子耗时'
:
gpu_costtimes
,
df
=
pd
.
DataFrame
({
'm'
:
_m
,
'n'
:
_n
,
'k'
:
_k
,
'线性层gemm量化算子耗时'
:
cost_times
,
'GPU算子耗时'
:
gpu_costtimes
,
'BLOCK_SIZE_M'
:
_configs_block_m
,
'BLOCK_SIZE_N'
:
_configs_block_n
,
'BLOCK_SIZE_K'
:
_configs_block_k
,
'BLOCK_SIZE_M'
:
_configs_block_m
,
'BLOCK_SIZE_N'
:
_configs_block_n
,
'BLOCK_SIZE_K'
:
_configs_block_k
,
...
...
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
View file @
9ab6cd98
...
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
...
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
from
transformer_engine.pytorch.fp8
import
blockwise_fp8_block_len
import
functools
import
functools
import
logging
import
logging
import
warnings
try
:
try
:
import
lightop
import
lightop
enable_lightop
=
True
except
ImportError
:
except
ImportError
:
enable_lightop
=
False
pass
from
transformer_engine.pytorch.utils
import
get_device_compute_capability
from
transformer_engine.pytorch.utils
import
use_lightop_w8a8
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
device_name
=
torch
.
cuda
.
get_device_properties
(
'cuda'
).
name
.
replace
(
" "
,
"_"
)
device_name
=
torch
.
cuda
.
get_device_properties
(
'cuda'
).
name
.
replace
(
" "
,
"_"
)
...
@@ -469,16 +467,16 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
...
@@ -469,16 +467,16 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
):
):
for
i
in
range
(
len
(
C_list
)):
for
i
in
range
(
len
(
C_list
)):
assert
C_list
[
i
]
is
not
None
assert
C_list
[
i
]
is
not
None
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
or
not
enable_lightop
:
if
use_lightop_w8a8
(
block_size
):
C_list
[
i
]
=
lightop
.
gemm_w8a8_wgrad_asm
(
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
block_size
,
output_dtype
,
'TN'
)
else
:
C_list
[
i
],
config
=
w8a8_block_int8_matmul_wgrad
(
C_list
[
i
],
config
=
w8a8_block_int8_matmul_wgrad
(
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
block_size
,
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
block_size
,
output_dtype
=
output_dtype
,
output_dtype
=
output_dtype
,
best_config
=
best_config
best_config
=
best_config
)
)
else
:
C_list
[
i
]
=
lightop
.
gemm_w8a8_wgrad_asm
(
A_list
[
i
],
B_list
[
i
],
As_list
[
i
],
Bs_list
[
i
],
C_list
[
i
],
accumulate
,
block_size
,
output_dtype
)
return
C_list
return
C_list
def
w8a8_block_int8_matmul_wgrad_batched
(
def
w8a8_block_int8_matmul_wgrad_batched
(
...
@@ -667,25 +665,24 @@ def apply_w8a8_block_int8_linear_helper(m: int,
...
@@ -667,25 +665,24 @@ def apply_w8a8_block_int8_linear_helper(m: int,
N
,
K
=
weight
.
shape
N
,
K
=
weight
.
shape
C_shape
=
q_input
.
shape
[:
-
1
]
+
(
N
,)
C_shape
=
q_input
.
shape
[:
-
1
]
+
(
N
,)
output
=
q_input
.
new_empty
(
C_shape
,
dtype
=
out_dtype
)
output
=
q_input
.
new_empty
(
C_shape
,
dtype
=
out_dtype
)
if
use_lightop_w8a8
(
block_size
):
output
=
lightop
.
gemm_w8a8_wgrad_asm
(
q_input
,
weight
,
x_scale
,
weight_scale
,
output
,
False
,
block_size
,
out_dtype
,
'TN'
)
else
:
print
(
f
"zhenggf 转置后传递给triton kernel, q_input:
{
q_input
.
shape
}
, x_scale:
{
x_scale
.
shape
}
, weight:
{
weight
.
shape
}
, weight_scale:
{
weight_scale
.
shape
}
"
)
print
(
f
"zhenggf 转置后传递给triton kernel, q_input:
{
q_input
.
shape
}
, x_scale:
{
x_scale
.
shape
}
, weight:
{
weight
.
shape
}
, weight_scale:
{
weight_scale
.
shape
}
"
)
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
or
not
enable_lightop
:
output
,
config
=
w8a8_block_int8_matmul_wgrad
(
output
,
config
=
w8a8_block_int8_matmul_wgrad
(
q_input
,
weight
,
x_scale
,
weight_scale
,
output
,
False
,
block_size
,
q_input
,
weight
,
x_scale
,
weight_scale
,
output
,
False
,
block_size
,
output_dtype
=
out_dtype
,
output_dtype
=
out_dtype
,
best_config
=
best_config
best_config
=
best_config
)
)
else
:
output
=
lightop
.
gemm_w8a8_wgrad_asm
(
q_input
,
weight
,
x_scale
,
weight_scale
,
output
,
False
,
block_size
,
out_dtype
)
if
not
torch
.
allclose
(
output
,
torch_output
,
rtol
=
1e-2
,
atol
=
5e-2
):
if
not
torch
.
allclose
(
output
,
torch_output
,
rtol
=
1e-2
,
atol
=
5e-2
):
print
(
"triton 精度检查不合格!!!"
)
print
(
"triton 精度检查不合格!!!"
)
else
:
else
:
print
(
"triton 精度检查合格"
)
print
(
"triton 精度检查合格"
)
# unit test end
# unit test end
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
:
if
not
use_lightop_w8a8
(
block_size
)
:
g
=
torch
.
cuda
.
CUDAGraph
()
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
with
torch
.
cuda
.
graph
(
g
):
for
it
in
range
(
1000
):
for
it
in
range
(
1000
):
...
@@ -865,7 +862,7 @@ def main():
...
@@ -865,7 +862,7 @@ def main():
best_config
=
[]
best_config
=
[]
apply_w8a8_block_int8_linear_batched_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
apply_w8a8_block_int8_linear_batched_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
128
:
if
not
use_lightop_w8a8
(
block_size
)
:
output
,
elapsed_time
,
gpu_costtime
,
config
=
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
output
,
elapsed_time
,
gpu_costtime
,
config
=
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
cost_times
.
append
(
elapsed_time
)
cost_times
.
append
(
elapsed_time
)
...
@@ -885,7 +882,7 @@ def main():
...
@@ -885,7 +882,7 @@ def main():
else
:
else
:
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
apply_w8a8_block_int8_linear_helper
(
m
=
m
,
n
=
n_list
[
i
],
k
=
k_list
[
i
],
block_size
=
block_size
,
out_dtype
=
out_dtype
,
best_config
=
best_config
)
if
get_device_compute_capability
()
<
(
9
,
3
)
or
block_size
[
1
]
!=
12
8
:
if
not
use_lightop_w8a
8
:
# 创建一个包含这三个列表的 DataFrame
# 创建一个包含这三个列表的 DataFrame
df
=
pd
.
DataFrame
({
'm'
:
_m
,
'n'
:
_n
,
'k'
:
_k
,
'线性层gemm量化算子耗时'
:
cost_times
,
'GPU算子耗时'
:
gpu_costtimes
,
df
=
pd
.
DataFrame
({
'm'
:
_m
,
'n'
:
_n
,
'k'
:
_k
,
'线性层gemm量化算子耗时'
:
cost_times
,
'GPU算子耗时'
:
gpu_costtimes
,
'BLOCK_SIZE_M'
:
_configs_block_m
,
'BLOCK_SIZE_N'
:
_configs_block_n
,
'BLOCK_SIZE_K'
:
_configs_block_k
,
'BLOCK_SIZE_M'
:
_configs_block_m
,
'BLOCK_SIZE_N'
:
_configs_block_n
,
'BLOCK_SIZE_K'
:
_configs_block_k
,
...
...
transformer_engine/pytorch/utils.py
View file @
9ab6cd98
...
@@ -10,7 +10,12 @@ import os
...
@@ -10,7 +10,12 @@ import os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
warnings
try
:
import
lightop
enable_lightop
=
True
except
ImportError
:
enable_lightop
=
False
import
transformer_engine.pytorch.cpp_extensions
as
ext
import
transformer_engine.pytorch.cpp_extensions
as
ext
from
.
import
torch_version
from
.
import
torch_version
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
...
@@ -455,6 +460,19 @@ if IS_HIP_EXTENSION:
...
@@ -455,6 +460,19 @@ if IS_HIP_EXTENSION:
import
re
import
re
return
(
re
.
search
(
'BW'
,
torch
.
cuda
.
get_device_name
(
torch
.
cuda
.
current_device
()))
is
not
None
)
return
(
re
.
search
(
'BW'
,
torch
.
cuda
.
get_device_name
(
torch
.
cuda
.
current_device
()))
is
not
None
)
def
use_lightop_w8a8
(
block_size
:
List
[
int
])
->
bool
:
"""Check whether to use lightop for w8a8"""
# Just return False because lightop is not ready now.
return
False
if
(
enable_lightop
):
return
get_device_compute_capability
()
>=
(
9
,
3
)
and
block_size
[
1
]
==
128
else
:
if
(
get_device_compute_capability
()
>=
(
9
,
3
)
and
block_size
[
1
]
==
128
):
warnings
.
warn
(
"Lightop is not available. Using default implementation for w8a8."
)
return
False
def
is_bf16_compatible
()
->
None
:
def
is_bf16_compatible
()
->
None
:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.
check on device compute capability to enforce sm_80 or higher.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment