OpenDAS / TransformerEngine / Commits / b7afba08

Commit b7afba08, authored Jun 05, 2025 by yuguo
[DCU] support block fp8 simu with int8 for MOE
parent 735227cd
Changes: 3 changed files, +507 additions, -3 deletions

- transformer_engine/pytorch/cpp_extensions/gemm.py (+62 -2)
- transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (+222 -1)
- transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (+223 -0)
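The commit's core idea is to emulate block-scaled FP8 GEMMs on DCU hardware with blockwise int8 quantization: tensors are quantized per [128, 128] block to int8 with one inverse scale per block, and the GEMMs dequantize on the fly. A minimal sketch of that quantization step, using hypothetical helper names that are not part of this commit:

import torch

def quantize_blockwise_int8(w, block_n=128, block_k=128):
    # Illustration only: quantize a [N, K] tensor per (block_n, block_k) tile
    # to int8, returning the int8 data and per-block inverse scales, mirroring
    # the (data, scale_inv) pairs the GEMMs in this commit consume.
    N, K = w.shape
    n_blocks = (N + block_n - 1) // block_n
    k_blocks = (K + block_k - 1) // block_k
    q = torch.empty(N, K, dtype=torch.int8, device=w.device)
    scale_inv = torch.empty(n_blocks, k_blocks, dtype=torch.float32, device=w.device)
    for i in range(n_blocks):
        for j in range(k_blocks):
            tile = w[i * block_n:(i + 1) * block_n, j * block_k:(j + 1) * block_k]
            amax = tile.abs().max().clamp(min=1e-12)
            scale = 127.0 / amax  # map the block's amax to the int8 range
            q[i * block_n:(i + 1) * block_n, j * block_k:(j + 1) * block_k] = \
                (tile * scale).round().clamp(-127, 127).to(torch.int8)
            scale_inv[i, j] = 1.0 / scale  # dequantization factor used by the kernels
    return q, scale_inv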
transformer_engine/pytorch/cpp_extensions/gemm.py (view file @ b7afba08)
@@ -10,8 +10,8 @@ import torch
 import transformer_engine_torch as tex
 from ..constants import TE_DType
 from ..utils import get_sm_count
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
@@ -205,6 +205,63 @@ def general_grouped_gemm(
    transa = layout[0] == "T"
    transb = layout[1] == "T"

    if int8_simulation_fp8 and (
        isinstance(A[0], Float8BlockwiseQTensorBase)
        or isinstance(B[0], Float8BlockwiseQTensorBase)
    ):
        assert len(set(m_splits)) == 1, (
            "Int8-simulation grouped GEMM currently requires identical m_splits "
            "(pad tokens the same way as for batched GEMM)."
        )
        assert not gelu, "GELU not supported with int8-simulation grouped GEMM."
        assert bias is None, "Bias not supported with int8-simulation grouped GEMM."
        assert not accumulate, "Accumulation not supported with int8-simulation grouped GEMM."
        if layout == "TN":  # forward GEMM: y = x @ w^T
            qx_data = [b._rowwise_data.view(dtype=torch.int8) for b in B]
            qw_data = [a._rowwise_data.view(dtype=torch.int8) for a in A]
            ref_scales_x = [b._rowwise_scale_inv for b in B]
            ref_scales_w = [a._rowwise_scale_inv for a in A]
            y = w8a8_block_int8_matmul_batched(
                qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
                output_dtype=out_dtype,
            )
            return y, None, None
        elif layout == "NN":  # dgrad GEMM: dx = dout @ w
            qdout_data = [b._rowwise_data.view(dtype=torch.int8) for b in B]
            qw_data = [a._columnwise_data.view(dtype=torch.int8) for a in A]
            ref_scales_dout = [b._rowwise_scale_inv for b in B]
            ref_scales_w = [a._columnwise_scale_inv for a in A]
            y = w8a8_block_int8_matmul_batched(
                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
                output_dtype=out_dtype,
            )
            return y, None, None
        elif layout == "NT":  # wgrad GEMM: dw = dout^T @ x
            qdout_data = [b._columnwise_data.view(dtype=torch.int8) for b in B]
            qx_data = [a._columnwise_data.view(dtype=torch.int8) for a in A]
            ref_scales_dout = [b._columnwise_scale_inv for b in B]
            ref_scales_x = [a._columnwise_scale_inv for a in A]
            y = w8a8_block_int8_matmul_wgrad_batched(
                qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
                output_dtype=out_dtype,
            )
            return y, None, None
        else:
            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")

    empty_tensor = _empty_tensor()
    empty_tensors = [empty_tensor] * num_gemms
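The `.view(dtype=torch.int8)` calls above reinterpret the quantized tensors' raw storage as int8 without copying; only the dtype tag changes, not the bytes. A minimal standalone illustration of that reinterpret (not code from this commit):

import torch

# A uint8-backed buffer, as a blockwise quantized tensor might store,
# can be viewed as int8 in place; bit patterns are preserved.
raw = torch.tensor([0, 1, 127, 128, 255], dtype=torch.uint8)
as_int8 = raw.view(dtype=torch.int8)
print(as_int8)  # tensor([   0,    1,  127, -128,   -1], dtype=torch.int8)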
@@ -276,6 +333,9 @@ def batchgemm(
    empty_tensor = torch.Tensor()
    empty_tensors = [torch.Tensor()] * num_gemms
    if int8_simulation_fp8:
        assert 0, (
            "batchgemm is not supported with int8 simulation; unset "
            "GROUPED_GEMM_BatchLinear and use the MoE grouped GEMM with token padding."
        )
    if gelu and not grad:
        gelu_input = [
            torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format)
            for o in out
        ]
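The message refers to the GROUPED_GEMM_BatchLinear switch; assuming it is read from the process environment (how it is consumed is not shown in this diff), routing MoE back onto the padded grouped GEMM would look roughly like:

import os

# Hypothetical: disable the batched-linear path so MoE falls back to the
# grouped GEMM, which the int8-simulation branch above does support.
os.environ.pop("GROUPED_GEMM_BatchLinear", None)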
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt.py (view file @ b7afba08)
@@ -338,7 +338,210 @@ def w8a8_block_int8_matmul(
    return C, config
@triton.jit
def _w8a8_block_int8_matmul_batched(
    # Pointers to inputs and output
    A, B, C, As, Bs,
    # Shape for matmul
    M, N, K,
    # Block size for block-wise quantization
    group_n, group_k,
    # Stride for inputs and output
    stride_a_batch, stride_am, stride_ak,
    stride_b_batch, stride_bk, stride_bn,
    stride_c_batch, stride_cm, stride_cn,
    stride_as_batch, stride_As_m, stride_As_k,
    stride_bs_batch, stride_Bs_k, stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton kernel that performs a batched matmul (dot product) of int8
    inputs `A` and `B` with block-wise quantization, dequantizing with the
    per-block scales `As` and `Bs`, and stores the result in `C`.
    """
    pid_mn = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    # Grouped (swizzled) mapping of the flat pid to an (m, n) output tile for
    # better L2 reuse, as in the standard Triton matmul tutorial.
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid_mn // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid_mn % group_size_m)
    pid_n = (pid_mn % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_bsn = pid_n * BLOCK_SIZE_N // group_n  # one B-scale block per output tile column
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + pid_batch * stride_a_batch + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + pid_batch * stride_b_batch + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    As_ptrs = As + pid_batch * stride_as_batch + offs_am * stride_As_m
    Bs_ptrs = Bs + pid_batch * stride_bs_batch + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
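For readers unfamiliar with the grouped launch order above, a minimal pure-Python sketch of the same pid swizzle, as an illustration rather than code from this commit:

import math

def swizzle(pid_mn, M, N, BM=64, BN=128, GROUP_SIZE_M=8):
    # Reproduces the kernel's flat-pid -> (pid_m, pid_n) mapping on the host.
    num_pid_m = math.ceil(M / BM)
    num_pid_n = math.ceil(N / BN)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid_mn // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid_mn % group_size_m)
    pid_n = (pid_mn % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# Consecutive pids walk GROUP_SIZE_M rows down one column of output tiles,
# so the same B tiles stay hot in cache across neighboring programs.
print([swizzle(p, M=512, N=512) for p in range(8)])
# [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)]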
def w8a8_block_int8_matmul_batched(
    A_list, B_list, As_list, Bs_list, block_size,
    output_dtype=torch.float16, best_config=None,
):
    A = torch.stack(A_list).contiguous()    # [B, M, K]
    B = torch.stack(B_list).contiguous()    # [B, N, K]
    As = torch.stack(As_list).contiguous()
    Bs = torch.stack(Bs_list).contiguous()
    assert A.shape[-1] == B.shape[-1]
    M = A.numel() // A.shape[-1] // A.shape[0]
    batch, N, K = B.shape
    block_n, block_k = block_size
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": 8,
        "num_warps": 4,
        "num_stages": 1,
    }

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            batch,
        )

    _w8a8_block_int8_matmul_batched[grid](
        A, B, C, As, Bs,
        M, N, K,
        block_n, block_k,
        A.stride(0), A.stride(-2), A.stride(-1),
        B.stride(0), B.stride(-1), B.stride(-2),
        C.stride(0), C.stride(-2), C.stride(-1),
        As.stride(0), As.stride(-1), As.stride(-2),
        Bs.stride(0), Bs.stride(-1), Bs.stride(-2),
        **config,
    )
    return C
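A minimal smoke test of this wrapper, assuming a GPU with Triton available; shapes are chosen to match the [128, 128] quantization blocks, and the random scales are illustrative only:

import torch
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul_batched

B, M, N, K, blk = 2, 256, 512, 1024, 128
A_list = [torch.randint(-127, 128, (M, K), dtype=torch.int8, device="cuda") for _ in range(B)]
B_list = [torch.randint(-127, 128, (N, K), dtype=torch.int8, device="cuda") for _ in range(B)]
# Per the stride arguments above, A-scales are expected pre-transposed to
# [K // blk, M] (the helper below does this permute), B-scales as [N // blk, K // blk].
As_list = [torch.rand(K // blk, M, device="cuda") * 1e-2 for _ in range(B)]
Bs_list = [torch.rand(N // blk, K // blk, device="cuda") * 1e-2 for _ in range(B)]

out = w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, [blk, blk],
                                     output_dtype=torch.float16)
print(out.shape)  # torch.Size([2, 256, 512])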
def apply_w8a8_block_int8_linear_batched_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[dict] = None,
):
    batch = 4
    q_input, x_scale, weight, weight_scale = _int8_gemm_helper(
        m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size
    )
    q_input_b = [q_input.clone().contiguous() for i in range(batch)]
    x_scale_b = [x_scale.clone().contiguous() for i in range(batch)]
    weight_b = [weight.clone().contiguous() for i in range(batch)]
    weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
    torch_output = native_w8a8_block_int8_matmul_batched(
        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size
    )
    # The Triton kernel expects the activation scales transposed to [K // block_k, M].
    x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
    output = w8a8_block_int8_matmul_batched(
        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
        output_dtype=out_dtype, best_config=best_config
    )
    if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
        print("Triton accuracy check FAILED!!!")
    else:
        print("Triton accuracy check passed.")
    # unit test end
def apply_w8a8_block_int8_linear_helper(
    m: int,
    n: int,
...
@@ -489,6 +692,24 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.b
    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def native_w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, block_size,
                                          output_dtype=torch.bfloat16):
    """
    Batched version of native block-wise quantized matmul.

    Args:
        A_list (List[Tensor]): [B, M, K]
        B_list (List[Tensor]): [B, N, K]
        As_list (List[Tensor]): [B, M, K // block_k]
        Bs_list (List[Tensor]): [B, N // block_n, K // block_k]

    Returns:
        Tensor: [B, M, N]
    """
    results = []
    for A, B, As, Bs in zip(A_list, B_list, As_list, Bs_list):
        C = native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype)
        results.append(C)
    return torch.stack(results)
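In equations, per batch element this reference (and the Triton kernel above) computes the dequantized product of the [M, K] and [N, K] int8 operands:

C_{m,n} = \sum_{k=0}^{K-1} A_{m,k}\, B_{n,k}\; \mathrm{As}_{m,\lfloor k/\mathrm{block\_k}\rfloor}\; \mathrm{Bs}_{\lfloor n/\mathrm{block\_n}\rfloor,\,\lfloor k/\mathrm{block\_k}\rfloor}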
def main():
    m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
    m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
...
@@ -529,7 +750,7 @@ def main():
    best_config = []
    apply_w8a8_block_int8_linear_batched_helper(
        m=m, n=n_list[i], k=k_list[i], block_size=block_size,
        out_dtype=out_dtype, best_config=best_config
    )
    output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(
        m=m, n=n_list[i], k=k_list[i], block_size=block_size,
        out_dtype=out_dtype, best_config=best_config
    )
    cost_times.append(elapsed_time)
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (view file @ b7afba08)
@@ -154,6 +154,110 @@ def _w8a8_block_int8_matmul(
    tl.store(c_ptrs, c, mask=c_mask)
@triton.jit
def _w8a8_block_int8_matmul_batched(
    # Pointers to inputs and output
    A, B, C, As, Bs,
    # Shape for matmul
    M, N, K,
    # Block size for block-wise quantization
    group_n, group_k,
    # Stride for inputs and output
    stride_a_batch, stride_am, stride_ak,
    stride_b_batch, stride_bk, stride_bn,
    stride_c_batch, stride_cm, stride_cn,
    stride_as_batch, stride_As_m, stride_As_k,
    stride_bs_batch, stride_Bs_k, stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton kernel for the batched wgrad matmul (dot product) of int8
    inputs `A` and `B` with block-wise quantization, dequantizing with the
    per-block scales `As` and `Bs`, and storing the result in `C`.
    """
    pid_mn = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid_mn // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid_mn % group_size_m)
    pid_n = (pid_mn % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    # Unlike the nt kernel, the wgrad B-scales are indexed per output column
    # rather than per n-block.
    offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + pid_batch * stride_a_batch + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + pid_batch * stride_b_batch + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    As_ptrs = As + pid_batch * stride_as_batch + offs_am * stride_As_m
    Bs_ptrs = Bs + pid_batch * stride_bs_batch + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, block_k: int
...
@@ -338,6 +442,107 @@ def w8a8_block_int8_matmul_wgrad(
    return C, config
def w8a8_block_int8_matmul_wgrad_batched(
    A_list, B_list, As_list, Bs_list, block_size,
    output_dtype=torch.float16, best_config=None,
):
    A = torch.stack(A_list).contiguous()    # [B, M, K]
    B = torch.stack(B_list).contiguous()    # [B, N, K]
    As = torch.stack(As_list).contiguous()
    Bs = torch.stack(Bs_list).contiguous()
    assert A.shape[-1] == B.shape[-1]
    M = A.numel() // A.shape[-1] // A.shape[0]
    batch, N, K = B.shape
    block_n, block_k = block_size
    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": 8,
        "num_warps": 4,
        "num_stages": 1,
    }

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            batch,
        )

    _w8a8_block_int8_matmul_batched[grid](
        A, B, C, As, Bs,
        M, N, K,
        block_n, block_k,
        A.stride(0), A.stride(-2), A.stride(-1),
        B.stride(0), B.stride(-1), B.stride(-2),
        C.stride(0), C.stride(-2), C.stride(-1),
        As.stride(0), As.stride(-1), As.stride(-2),
        # Note: Bs strides are passed as (-2, -1) here, unlike the nt wrapper,
        # matching this kernel's per-column B-scale indexing.
        Bs.stride(0), Bs.stride(-2), Bs.stride(-1),
        **config,
    )
    return C
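Relative to the nt wrapper, the differences are the per-column B-scale indexing and the swapped Bs stride order, so both scale tensors arrive transposed (the helper below does the permutes). A hedged sketch of the stacked layouts this wrapper expects, derived from the stride arguments above:

import torch
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad_batched

blk, B, M, N, K = 128, 2, 256, 512, 1024
A_list  = [torch.randint(-127, 128, (M, K), dtype=torch.int8, device="cuda") for _ in range(B)]
B_list  = [torch.randint(-127, 128, (N, K), dtype=torch.int8, device="cuda") for _ in range(B)]
As_list = [torch.rand(K // blk, M, device="cuda") for _ in range(B)]  # per (k-block, row)
Bs_list = [torch.rand(K // blk, N, device="cuda") for _ in range(B)]  # per (k-block, column)
dw = w8a8_block_int8_matmul_wgrad_batched(A_list, B_list, As_list, Bs_list, [blk, blk])
print(dw.shape)  # torch.Size([2, 256, 512])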
def apply_w8a8_block_int8_linear_batched_helper(
    m: int,
    n: int,
    k: int,
    out_dtype: Type[torch.dtype] = torch.float16,
    device: str = "cuda",
    block_size: List[int] = [128, 128],
    bias: Optional[torch.Tensor] = None,
    best_config: Optional[dict] = None,
):
    batch = 4
    q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(
        m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size
    )
    q_input_b = [q_input.clone().contiguous() for i in range(batch)]
    x_scale_b = [x_scale.clone().contiguous() for i in range(batch)]
    weight_b = [weight.clone().contiguous() for i in range(batch)]
    weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
    torch_output = native_w8a8_block_int8_matmul_batched(
        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size
    )
    # The wgrad Triton kernel expects both scale tensors transposed before launch.
    x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
    weight_scale_b = [ws.permute(1, 0).contiguous() for ws in weight_scale_b]
    output = w8a8_block_int8_matmul_wgrad_batched(
        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
        output_dtype=out_dtype, best_config=best_config
    )
    if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
        print("Triton accuracy check FAILED!!!")
    else:
        print("Triton accuracy check passed.")
    # unit test end
def apply_w8a8_block_int8_linear_helper(
    m: int,
...
@@ -494,6 +699,23 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.b
    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


def native_w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, block_size,
                                          output_dtype=torch.bfloat16):
    """
    Batched version of native block-wise quantized matmul.

    Args:
        A_list (List[Tensor]): [B, M, K]
        B_list (List[Tensor]): [B, N, K]
        As_list (List[Tensor]): [B, M, K // block_k]
        Bs_list (List[Tensor]): [B, N, K // block_k]

    Returns:
        Tensor: [B, M, N]
    """
    results = []
    for A, B, As, Bs in zip(A_list, B_list, As_list, Bs_list):
        C = native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype)
        results.append(C)
    return torch.stack(results)
def main():
    m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
    m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
...
@@ -534,6 +756,7 @@ def main():
    best_config = []
    apply_w8a8_block_int8_linear_batched_helper(
        m=m, n=n_list[i], k=k_list[i], block_size=block_size,
        out_dtype=out_dtype, best_config=best_config
    )
    output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(
        m=m, n=n_list[i], k=k_list[i], block_size=block_size,
        out_dtype=out_dtype, best_config=best_config
    )