OpenDAS / TransformerEngine / Commits / 3f800f01

Commit 3f800f01, authored Sep 18, 2025 by wenjh

    Enable lightop w8a8

    Signed-off-by: wenjh <wenjh@sugon.com>

Parent: 00fcd784

Showing 6 changed files with 211 additions and 324 deletions (+211 −324):
    tests/pytorch/test_float8_blockwise_gemm_exact.py                  +8   −39
    tests/pytorch/test_int8_blockwise_gemm_exact.py                    +4   −37
    transformer_engine/pytorch/cpp_extensions/gemm.py                  +89  −77
    transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py        +55  −67
    transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py  +55  −86
    transformer_engine/pytorch/utils.py                                +0   −18
tests/pytorch/test_float8_blockwise_gemm_exact.py (+8 −39)

@@ -6,10 +6,6 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-try:
-    import lightop
-except ImportError:
-    pass
 from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)

@@ -17,12 +13,11 @@ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
 )
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
 from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm

 def fp8_blockwise_gemm_supported() -> bool:
     supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()

@@ -174,40 +169,14 @@ def cublas_gemm_fp8_blockwise_case(
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
     if IS_HIP_EXTENSION and int8_simulation_fp8:
-        if use_lightop_w8a8([block_len, block_len]):
-            if (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y = lightop.gemm_w8a8_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w,
-                    [block_len, block_len], out_dtype, 'TN'
-                )
-            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y = lightop.gemm_w8a8_xgrad_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w,
-                    [block_len, block_len], out_dtype, 'TN'
-                )
-            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
-                y = lightop.gemm_w8a8_wgrad_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w,
-                    out.clone() if accumulate else None, accumulate,
-                    [block_len, block_len], out_dtype, 'TN'
-                )
-            else:
-                assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
-        else:
-            if (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                              [block_len, block_len], output_dtype=out_dtype)
-            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                              [block_len, block_len], output_dtype=out_dtype)
-            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul_wgrad(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                                    out.clone() if accumulate else None, accumulate,
-                                                    [block_len, block_len], output_dtype=out_dtype)
-            else:
-                assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
+        if (not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
+        elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "NN", None)
+        elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, accumulate, "NT", out.clone() if accumulate else None)
+        else:
+            assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
     else:
         # cuBLAS GEMM
         # return type is out, bias_grad, gelu_input, extra_output
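For orientation, the three branches above collapse the forward, activation-grad, and weight-grad cases into the single new `w8a8_int8_general_gemm` entry point, selected purely by the `layout` string. A minimal sketch of that mapping, not part of the commit, assuming `qw`/`qx` are blockwise-quantized tensors built by the test's `Float8BlockQuantizer` fixtures:

    # Sketch only: argument order and layouts follow the diff above.
    from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm

    def run_all_cases(qw, qx, out_dtype, accumulate=False, out=None):
        y_fwd = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)   # forward
        dx = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "NN", None)      # activation grad
        dw = w8a8_int8_general_gemm(qw, qx, out_dtype, accumulate, "NT",       # weight grad,
                                    out.clone() if accumulate else None)       # optionally accumulated
        return y_fwd, dx, dw

Only the "NT" (weight-grad) path accepts `accumulate` and an output tensor; the other two layouts assert both are unset, as shown in gemm.py below.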
tests/pytorch/test_int8_blockwise_gemm_exact.py (+4 −37)

@@ -2,20 +2,13 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import warnings
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
 )
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
+from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
 from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm

@@ -201,15 +194,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_asm(
-            qx_data, qw_data, ref_scales_x, ref_scales_w,
-            [block_len, block_len], out_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                      [block_len, block_len], output_dtype=out_dtype)
+    y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
     # print("int8 gemm output: ", y)
     # print("int8 gemm output shape: ", y.shape)

@@ -384,15 +369,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_asm(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-            [block_len, block_len], dx_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul(qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-                                      [block_len, block_len], output_dtype=dx_dtype)
+    y = w8a8_int8_general_gemm(qw, qdout, dx_dtype, False, "NN", None)
     # print("int8 gemm dx: ", y)

@@ -568,17 +545,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_wgrad_asm(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x,
-            dw.clone() if accumulate else None, accumulate,
-            [block_len, block_len], dw_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x,
-                                            dw.clone() if accumulate else None, accumulate,
-                                            [block_len, block_len], output_dtype=dw_dtype)
+    y = w8a8_int8_general_gemm(qx, qdout, dw_dtype, accumulate, "NT", dw.clone() if accumulate else None)
     # print("int8 gemm dw: ", y)
transformer_engine/pytorch/cpp_extensions/gemm.py (+89 −77)

@@ -8,20 +8,21 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
+import warnings
 try:
     import lightop
+    enable_lightop = True
 except ImportError:
-    pass
+    enable_lightop = False
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
-from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.triton.per_token_group_quant import (
     per_token_quant_fp8_to_int8,
     per_token_quant_fp8_to_int8_v2,
     per_token_quant_fp8_to_int8_opt,

@@ -45,6 +46,71 @@ __all__ = [
     "batchgemm",
 ]

+def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, output_dtype=torch.float16):
+    for i in range(len(C_list)):
+        assert C_list[i] is not None
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            C_list[i] = lightop.gemm_w8a8_wgrad_asm(A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
+                                                    accumulate, blockwise_fp8_block_len, output_dtype, "TN")
+        else:
+            C_list[i], _ = w8a8_block_int8_matmul_wgrad(A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
+                                                        accumulate, blockwise_fp8_block_len, output_dtype, None)
+    return C_list
+
+def w8a8_int8_general_gemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = None,
+    accumulate: bool = False,
+    layout: str = "TN",
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if layout == "TN":
+        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with TN layout"
+        assert out is None, "Output tensor not supported in w8a8_general_gemm with TN layout"
+        qx_data = B._rowwise_data.view(dtype=torch.int8)
+        qw_data = A._rowwise_data.view(dtype=torch.int8)
+        ref_scales_x = B._rowwise_scale_inv
+        ref_scales_w = A._rowwise_scale_inv
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w,
+                                      [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w,
+                                          [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return y
+    elif layout == "NN":
+        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with NN layout"
+        assert out is None, "Output tensor not supported in w8a8_general_gemm with NN layout"
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            qdout_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._rowwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._rowwise_scale_inv
+            y = lightop.gemm_w8a8_xgrad_asm(qdout_data, qw_data, ref_scales_dout, ref_scales_w,
+                                            [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'NN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            qdout_data = B._rowwise_data.view(dtype=torch.int8)
+            qw_data = A._columnwise_data.view(dtype=torch.int8)
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(qdout_data, qw_data, ref_scales_dout, ref_scales_w,
+                                          [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return y
+    elif layout == "NT":
+        qdout_data = B._columnwise_data.view(dtype=torch.int8)
+        qx_data = A._columnwise_data.view(dtype=torch.int8)
+        ref_scales_dout = B._columnwise_scale_inv
+        ref_scales_x = A._columnwise_scale_inv
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
+                                              [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            out, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
+                                                  [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return out
+    else:
+        raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
+
 def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
     """Validate whether a GEMM scaling factor is consistent with its usage"""

@@ -92,78 +158,6 @@ def general_gemm(
     # + "a valid `ub` communicator object."
     # )
-    if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
-        assert not gelu, "GELU not supported with int8 simulation"
-        assert gelu_in is None, "GELU input not supported with int8 simulation"
-        assert bias is None, "Bias not supported with int8 simulation"
-        assert ub is None, "User buffer not supported with int8 simulation"
-        assert ub_type is None, "User buffer type not supported with int8 simulation"
-        assert extra_output is None, "Extra output not supported with int8 simulation"
-        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
-        if layout == "TN":
-            qx_data = B._rowwise_data.view(dtype=torch.int8)
-            qw_data = A._rowwise_data.view(dtype=torch.int8)
-            ref_scales_x = B._rowwise_scale_inv
-            ref_scales_w = A._rowwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                          [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
-            else:
-                y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w,
-                                              [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype)
-            return y, None, None, None
-        elif layout == "NN":
-            qdout_data = B._rowwise_data.view(dtype=torch.int8)
-            qw_data = A._columnwise_data.view(dtype=torch.int8)
-            ref_scales_dout = B._rowwise_scale_inv
-            ref_scales_w = A._columnwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                y = lightop.gemm_w8a8_asm(qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-                                          [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
-            else:
-                y, _ = w8a8_block_int8_matmul(qdout_data, qw_data, ref_scales_dout, ref_scales_w,
-                                              [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype)
-            return y, None, None, None
-        elif layout == "NT":
-            qdout_data = B._columnwise_data.view(dtype=torch.int8)
-            qx_data = A._columnwise_data.view(dtype=torch.int8)
-            ref_scales_dout = B._columnwise_scale_inv
-            ref_scales_x = A._columnwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
-                                                  [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
-            else:
-                out, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
-                                                      [blockwise_fp8_block_len, blockwise_fp8_block_len], output_dtype=out_dtype)
-            return out, None, None, None
-        else:
-            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
         if ub_type == tex.CommOverlapType.RS:

@@ -195,6 +189,25 @@ def general_gemm(
             or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
         ):
             raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
+    if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
+        assert not gelu, "GELU not supported with int8 simulation"
+        assert gelu_in is None, "GELU input not supported with int8 simulation"
+        assert bias is None, "Bias not supported with int8 simulation"
+        assert ub is None, "User buffer not supported with int8 simulation"
+        assert ub_type is None, "User buffer type not supported with int8 simulation"
+        assert extra_output is None, "Extra output not supported with int8 simulation"
+        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
+        y = w8a8_int8_general_gemm(A, B, out_dtype, accumulate, layout, out)
+        return y, None, None, None
     if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
         assert not gelu, "GELU not supported with int8 simulation"

@@ -480,8 +493,7 @@ def general_grouped_gemm(
         ref_scales_x = [a._columnwise_scale_inv for a in A]
         out = w8a8_block_int8_matmul_wgrad_batched_native(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-            output_dtype=out_dtype
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
+            out_dtype
         )
         return out, bias, gelu_input
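The same dispatch pattern repeats in every branch of `w8a8_int8_general_gemm` above: use the lightop asm kernel when the device reports compute capability at least (9, 3) and the scaling block length is 128, otherwise warn and fall back to the Triton path. A minimal sketch of that gate, not part of the commit; the helper names here are illustrative, and the sketch uses `torch.cuda.get_device_capability()` where gemm.py uses TE's own `get_device_compute_capability()`:

    # Sketch only: the eligibility check and fallback used by the new code paths.
    import warnings
    import torch

    def _lightop_eligible(block_len: int) -> bool:
        # lightop's asm kernels apply only for capability >= (9, 3)
        # and only for 128x128 scaling blocks.
        return torch.cuda.get_device_capability() >= (9, 3) and block_len == 128

    def dispatch_w8a8(lightop_fn, triton_fn, block_len: int):
        if _lightop_eligible(block_len):
            return lightop_fn
        warnings.warn("Lightop is not available. Using default implementation for w8a8.")
        return triton_fn

Note that, as committed, the gate checks capability and block length but not the `enable_lightop` flag set in the import guard, so ineligible devices take the Triton fallback while eligible devices assume the `lightop` import succeeded.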
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (+55 −67)

@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8

 logger = logging.getLogger(__name__)

@@ -613,45 +608,41 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    if use_lightop_w8a8(block_size):
-        output = lightop.gemm_w8a8_asm(
-            q_input, weight, x_scale, weight_scale, block_size,
-            out_dtype, 'TN'
-        )
-    else:
-        output, config = w8a8_block_int8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size,
-            output_dtype=out_dtype,
-            best_config=best_config
-        )
+    output, config = w8a8_block_int8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size,
+        output_dtype=out_dtype,
+        best_config=best_config
+    )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("w8a8_block_int8 accuracy check FAILED!!!")
     else:
         print("w8a8_block_int8 accuracy check passed")
     # unit test end
-    if not use_lightop_w8a8(block_size):
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g):
         for it in range(1000):
             output, _ = w8a8_block_int8_matmul(
                 q_input, weight, x_scale, weight_scale, block_size,
                 output_dtype=out_dtype,
                 best_config=best_config
             )
     torch.cuda.synchronize()
     start_time_ = time.time()  # start timing
     g.replay()
     torch.cuda.synchronize()
     end_time_ = time.time()  # stop timing
     elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
     print("_time:{} us\n".format(elapsed_time))
     quantiles = [0.5, 0.2, 0.8]
     gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
     if bias is not None:
         output = output + bias
     return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config

 def get_triton_cache(file_path, n, k, block_n, block_k):
     # Returns the contents of the JSON file as a dict

@@ -808,36 +799,33 @@ def main():
             best_config = []
             apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
-            if not use_lightop_w8a8(block_size):
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
             gpu_costtimes.append(gpu_costtime)
             _n.append(n_list[i])
             _k.append(k_list[i])
             _m.append(m)
             print(f"zhenggf, {config}")
             print(f"zhenggf, {config.kwargs}")
             _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
             _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
             _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
             _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
-            else:
-                apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
-    if not use_lightop_w8a8(block_size):
     # Build a DataFrame from these lists
     df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k,
                        'linear-layer quantized GEMM op time': cost_times,
                        'GPU op time': gpu_costtimes,
                        'BLOCK_SIZE_M': _configs_block_m,
                        'BLOCK_SIZE_N': _configs_block_n,
                        'BLOCK_SIZE_K': _configs_block_k,
                        'GROUP_SIZE_M': _configs_block_group_m,
                        'num_warps': _configs_block_num_warps,
                        'num_stages': _configs_block_num_stages,
                        #'kpack': _configs_kpack
                        })
     # Write the DataFrame to an Excel file
     df.to_excel('gemmoutput.xlsx', index=False)
     print("The table has been saved to gemmoutput.xlsx.")

 if __name__ == "__main__":
     main()
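The benchmarking block that this commit un-gates captures 1000 kernel launches into a single `torch.cuda.CUDAGraph` and times one replay, which removes per-launch Python and driver overhead from the measurement. A self-contained sketch of that pattern, not part of the commit, shown with a plain matmul so it runs without the Triton kernels; note the diff's print labels the value "us" although it is computed in milliseconds, so the sketch prints ms:

    # Sketch only: the CUDA-graph timing structure used in the helper above.
    import time
    import torch

    def bench_graph(fn, iters=1000):
        fn()                               # warm up outside capture
        torch.cuda.synchronize()
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):          # capture `iters` launches into one graph
            for _ in range(iters):
                fn()
        torch.cuda.synchronize()
        start = time.time()                # start timing
        g.replay()                         # replay all captured launches at once
        torch.cuda.synchronize()
        return round((time.time() - start) * 1000, 7)  # elapsed ms for `iters` calls

    if torch.cuda.is_available():
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        print(f"graph elapsed: {bench_graph(lambda: a @ b)} ms")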
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (+55 −86)

@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8

 logger = logging.getLogger(__name__)
 device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")

@@ -461,24 +456,6 @@ def w8a8_block_int8_matmul_wgrad(
     return C, config

-def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate,
-                                                block_size, output_dtype=torch.float16, best_config=None):
-    for i in range(len(C_list)):
-        assert C_list[i] is not None
-        if use_lightop_w8a8(block_size):
-            C_list[i] = lightop.gemm_w8a8_wgrad_asm(A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
-                                                    accumulate, block_size, output_dtype, 'TN')
-        else:
-            C_list[i], config = w8a8_block_int8_matmul_wgrad(A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i],
-                                                             accumulate, block_size, output_dtype=output_dtype,
-                                                             best_config=best_config)
-    return C_list

 def w8a8_block_int8_matmul_wgrad_batched(
     A_list, B_list, As_list, Bs_list, C_list, accumulate,
     block_size, output_dtype=torch.float16, best_config=None,

@@ -665,46 +642,42 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     N, K = weight.shape
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
-    if use_lightop_w8a8(block_size):
-        print(f"zhenggf, passing to the triton kernel after transpose, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-        output = lightop.gemm_w8a8_wgrad_asm(
-            q_input, weight, x_scale, weight_scale, output, False, block_size,
-            out_dtype, 'TN'
-        )
-    else:
-        print(f"zhenggf, passing to the triton kernel after transpose, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
-        output, config = w8a8_block_int8_matmul_wgrad(
-            q_input, weight, x_scale, weight_scale, output, False, block_size,
-            output_dtype=out_dtype,
-            best_config=best_config
-        )
+    print(f"zhenggf, passing to the triton kernel after transpose, q_input: {q_input.shape}, x_scale: {x_scale.shape}, weight: {weight.shape}, weight_scale: {weight_scale.shape}")
+    output, config = w8a8_block_int8_matmul_wgrad(
+        q_input, weight, x_scale, weight_scale, output, False, block_size,
+        output_dtype=out_dtype,
+        best_config=best_config
+    )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check FAILED!!!")
     else:
         print("triton accuracy check passed")
     # unit test end
-    if not use_lightop_w8a8(block_size):
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g):
         for it in range(1000):
             output, _ = w8a8_block_int8_matmul_wgrad(
                 q_input, weight, x_scale, weight_scale, output, False, block_size,
                 output_dtype=out_dtype,
                 best_config=best_config
             )
     torch.cuda.synchronize()
     start_time_ = time.time()  # start timing
     g.replay()
     torch.cuda.synchronize()
     end_time_ = time.time()  # stop timing
     elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
     print("_time:{} us\n".format(elapsed_time))
     quantiles = [0.5, 0.2, 0.8]
     gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
     if bias is not None:
         output = output + bias
     return output.to(dtype=out_dtype), elapsed_time, gpu_costtime, config

 def get_triton_cache(file_path, n, k, block_n, block_k):
     # Returns the contents of the JSON file as a dict

@@ -862,36 +835,32 @@ def main():
             best_config = []
             apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
-            if not use_lightop_w8a8(block_size):
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
             gpu_costtimes.append(gpu_costtime)
             _n.append(n_list[i])
             _k.append(k_list[i])
             _m.append(m)
             print(f"zhenggf, {config}")
             print(f"zhenggf, {config.kwargs}")
             _configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
             _configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
             _configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
             _configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
-            else:
-                apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
-    if not use_lightop_w8a8:
     # Build a DataFrame from these lists
     df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k,
                        'linear-layer quantized GEMM op time': cost_times,
                        'GPU op time': gpu_costtimes,
                        'BLOCK_SIZE_M': _configs_block_m,
                        'BLOCK_SIZE_N': _configs_block_n,
                        'BLOCK_SIZE_K': _configs_block_k,
                        'GROUP_SIZE_M': _configs_block_group_m,
                        'num_warps': _configs_block_num_warps,
                        'num_stages': _configs_block_num_stages,
                        #'kpack': _configs_kpack
                        })
     # Write the DataFrame to an Excel file
     df.to_excel('gemmoutput.xlsx', index=False)
     print("The table has been saved to gemmoutput.xlsx.")

 if __name__ == "__main__":
     main()
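One detail worth noting in the last hunk: the removed guard before the DataFrame block reads `if not use_lightop_w8a8:` with no call parentheses, so it tested the function object itself rather than its result. A function object is always truthy, so the branch was dead code, which the removal makes moot. A two-line sketch, with a stand-in for the removed helper:

    # Sketch only: why `if not use_lightop_w8a8:` never fires.
    def use_lightop_w8a8(block_size):       # stand-in for the removed helper
        return False

    assert bool(use_lightop_w8a8) is True   # a function object is always truthy
    assert (not use_lightop_w8a8) is False  # so the guarded block never ran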
transformer_engine/pytorch/utils.py (+0 −18)

@@ -10,12 +10,6 @@ import os
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
-import warnings
-try:
-    import lightop
-    enable_lightop = True
-except ImportError:
-    enable_lightop = False
 import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

@@ -460,18 +454,6 @@ if IS_HIP_EXTENSION:
         import re
         return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

-    def use_lightop_w8a8(block_size: List[int]) -> bool:
-        """Check whether to use lightop for w8a8"""
-        # Just return False because lightop is not ready now.
-        return False
-        if (enable_lightop):
-            return get_device_compute_capability() >= (9, 3) and block_size[1] == 128
-        else:
-            if (get_device_compute_capability() >= (9, 3) and block_size[1] == 128):
-                warnings.warn("Lightop is not available. Using default implementation for w8a8.")
-            return False

 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
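Note that the removed helper returned `False` unconditionally: the `return False` on its first line made the capability check below it unreachable, so every caller always took the Triton fallback, which is why this commit can inline the real gate at the call sites in gemm.py instead. For reference, a minimal sketch of what the unreachable logic would have computed, with `enable_lightop` and the compute capability passed in explicitly rather than read from module state:

    # Sketch only: the intended (but unreachable) logic of the removed helper.
    from typing import List, Tuple
    import warnings

    def use_lightop_w8a8_intended(block_size: List[int],
                                  enable_lightop: bool,
                                  compute_capability: Tuple[int, int]) -> bool:
        eligible = compute_capability >= (9, 3) and block_size[1] == 128
        if enable_lightop:
            return eligible
        if eligible:
            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
        return False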