OpenDAS / TransformerEngine · Commits

Commit 3653fbfb, authored Jun 16, 2025 by yuguo

[DCU] fix int8 simul fp8 fused wgrad accumulation

Parent: ecdd8251
Showing 2 changed files with 45 additions and 36 deletions (+45 −36).
transformer_engine/pytorch/cpp_extensions/gemm.py (+6, −27)
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py (+39, −9)
transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -79,9 +79,6 @@ def general_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            y = y + out
         return y, None, None, None
     elif layout == "NN":
@@ -98,9 +95,6 @@ def general_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            y = y + out
         return y, None, None, None
     elif layout == "NT":
@@ -113,14 +107,11 @@ def general_gemm(
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        y, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+        out, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            y = y + out
-        return y, None, None, None
+        return out, None, None, None
     else:
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
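The substance of the fix is visible in this hunk: the old NT (wgrad) path computed a fresh `y` and then performed a separate Python-side `y = y + out`, costing an extra elementwise kernel launch and, for low-precision `out_dtype`, an extra rounding step. The new path hands `out` and `accumulate` down to the Triton wrapper so the addition happens in the GEMM epilogue. A schematic contrast (a sketch only, with `kernel` standing in for `w8a8_block_int8_matmul_wgrad` and argument names abbreviated):

    # Before: unfused accumulation.
    y, _ = kernel(qdout, qx, scales_dout, scales_x, [128, 128], output_dtype=out_dtype)
    if accumulate:
        assert out is not None
        y = y + out          # second full pass over the output buffer

    # After: the kernel reads C, adds, and stores once.
    out, _ = kernel(qdout, qx, scales_dout, scales_x, out, accumulate, [128, 128],
                    output_dtype=out_dtype)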
@@ -226,10 +217,6 @@ def general_grouped_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
         return y, None, None
     elif layout == "NN":
@@ -246,10 +233,6 @@ def general_grouped_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
         return y, None, None
     elif layout == "NT":
@@ -262,15 +245,11 @@ def general_grouped_gemm(
         ref_scales_dout = [b._columnwise_scale_inv for b in B]
         ref_scales_x = [a._columnwise_scale_inv for a in A]
-        y, _ = w8a8_block_int8_matmul_wgrad_batched(qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+        out, _ = w8a8_block_int8_matmul_wgrad_batched(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
             output_dtype=out_dtype)
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
-        return y, None, None
+        return out, None, None
     else:
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
transformer_engine/pytorch/triton/blockwise_int8_gemm_nt_wgrad.py
@@ -82,6 +82,7 @@ def _w8a8_block_int8_matmul(
     stride_As_k,
     stride_Bs_k,
     stride_Bs_n,
+    accumulate: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -151,6 +152,8 @@ def _w8a8_block_int8_matmul(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    if accumulate:
+        c += tl.load(c_ptrs, mask=c_mask)
     tl.store(c_ptrs, c, mask=c_mask)
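For reference, the epilogue pattern added here can be exercised in isolation. The sketch below is a minimal elementwise Triton kernel, not the production GEMM; the kernel name and sizes are made up for illustration. Because `accumulate` is a `tl.constexpr`, the branch is resolved at compile time and costs nothing when accumulation is off.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _acc_store_kernel(x_ptr, c_ptr, n, accumulate: tl.constexpr, BLOCK: tl.constexpr):
        # Each program instance handles one BLOCK-wide tile of the output.
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        c = tl.load(x_ptr + offs, mask=mask)
        if accumulate:
            # Fused accumulation: read the existing output and add before storing.
            c += tl.load(c_ptr + offs, mask=mask)
        tl.store(c_ptr + offs, c, mask=mask)

    x = torch.randn(1000, device="cuda")
    c = torch.randn(1000, device="cuda")
    expected = x + c
    _acc_store_kernel[(triton.cdiv(1000, 256),)](x, c, 1000, accumulate=True, BLOCK=256)
    torch.testing.assert_close(c, expected)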
@@ -185,6 +188,7 @@ def _w8a8_block_int8_matmul_batched(
     stride_bs_batch,
     stride_Bs_k,
     stride_Bs_n,
+    accumulate: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -256,6 +260,8 @@ def _w8a8_block_int8_matmul_batched(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    if accumulate:
+        c += tl.load(c_ptrs, mask=c_mask)
     tl.store(c_ptrs, c, mask=c_mask)

 @functools.lru_cache
@@ -304,6 +310,8 @@ def w8a8_block_int8_matmul_wgrad(
     B: torch.Tensor,
     As: torch.Tensor,
     Bs: torch.Tensor,
+    C: torch.Tensor,
+    accumulate: bool,
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
     best_config: Optional[dict] = None
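With the extended signature, the destination tensor and the accumulate flag now come before `block_size`, so positional callers have to be updated in lockstep (the helper functions later in this file are). An illustrative call, assuming int8 operands `A` and `B` with blockwise inverse scales `As` and `Bs`; `main_grad` is a hypothetical preexisting gradient buffer:

    # Fresh output: pass C=None with accumulate=False; the wrapper allocates C.
    y, _ = w8a8_block_int8_matmul_wgrad(A, B, As, Bs, None, False, [128, 128],
                                        output_dtype=torch.bfloat16)

    # Fused wgrad accumulation: C must already exist; the kernel adds into it.
    y, _ = w8a8_block_int8_matmul_wgrad(A, B, As, Bs, main_grad, True, [128, 128],
                                        output_dtype=torch.float32)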
@@ -338,6 +346,10 @@ def w8a8_block_int8_matmul_wgrad(
# assert triton.cdiv(N, block_n) == Bs.shape[0]
# assert triton.cdiv(K, block_k) == Bs.shape[1]
if
accumulate
:
assert
C
is
not
None
if
C
is
None
:
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
...
...
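`A.shape[:-1] + (N,)` keeps the leading (row) dimensions and swaps the reduction dimension K for N, and `new_empty` places the buffer on A's device. A quick shape check under assumed sizes:

    import torch

    A = torch.zeros(4096, 1024, dtype=torch.int8, device="cuda")  # [M, K]
    N = 2048
    C = A.new_empty(A.shape[:-1] + (N,), dtype=torch.bfloat16)    # [M, N]
    assert C.shape == (4096, 2048) and C.device == A.device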
@@ -435,6 +447,7 @@ def w8a8_block_int8_matmul_wgrad(
As
.
stride
(
0
),
Bs
.
stride
(
-
2
),
Bs
.
stride
(
-
1
),
accumulate
,
# Bs.stride(1),
# Bs.stride(0),
# **config,
...
...
@@ -445,7 +458,7 @@
 def w8a8_block_int8_matmul_wgrad_batched(
-        A_list, B_list, As_list, Bs_list,
+        A_list, B_list, As_list, Bs_list, C_list, accumulate,
         block_size, output_dtype=torch.float16, best_config=None):
     A = torch.stack(A_list).contiguous()  # [B, M, K]
@@ -462,8 +475,17 @@ def w8a8_block_int8_matmul_wgrad_batched(
     batch, N, K = B.shape
     block_n, block_k = block_size
+    if accumulate:
+        if C_list is None:
+            assert False
+        else:
+            C = torch.stack(C_list).contiguous()
+    else:
+        if C_list is None:
+            C_shape = A.shape[:-1] + (N,)
+            C = A.new_empty(C_shape, dtype=output_dtype)
+        else:
+            C = torch.stack(C_list).contiguous()
     config = {
         "BLOCK_SIZE_M": 64,
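The `assert False` branch guards accumulation without a destination list. An equivalent, more self-documenting formulation (an illustrative rewrite, not the committed code) would be:

    if accumulate:
        assert C_list is not None, "accumulate=True requires preallocated outputs"
        C = torch.stack(C_list).contiguous()
    elif C_list is None:
        C = A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)
    else:
        C = torch.stack(C_list).contiguous()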
@@ -506,6 +528,7 @@ def w8a8_block_int8_matmul_wgrad_batched(
Bs
.
stride
(
0
),
Bs
.
stride
(
-
2
),
Bs
.
stride
(
-
1
),
accumulate
,
**
config
,
)
...
...
@@ -527,6 +550,10 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     x_scale_b = [x_scale.clone().contiguous() for i in range(batch)]
     weight_b = [weight.clone().contiguous() for i in range(batch)]
     weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
     torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
@@ -537,7 +564,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     # print(f"zhenggf, after transpose, passing to the triton kernel: q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
-    output = w8a8_block_int8_matmul_wgrad_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
+    output = w8a8_block_int8_matmul_wgrad_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, output, False, block_size,
         output_dtype=out_dtype, best_config=best_config)
@@ -568,9 +595,12 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     x_scale = x_scale.permute(1, 0).contiguous()
     weight_scale = weight_scale.permute(1, 0).contiguous()
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf, after transpose, passing to the triton kernel: q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-    output, config = w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, block_size,
+    output, config = w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size,
         output_dtype=out_dtype, best_config=best_config)
@@ -587,7 +617,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     with torch.cuda.graph(g):
         for it in range(1000):
-            output, _ = w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, block_size,
+            output, _ = w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size,
                 output_dtype=out_dtype, best_config=best_config)
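Preallocating `output` and threading it through the call matters here: work captured under `torch.cuda.graph` replays against fixed buffer addresses, so every tensor the kernel touches should be allocated before capture rather than freshly created on each call. A minimal capture sketch, assuming a CUDA device (the tensors and the doubling op are placeholders):

    import torch

    static_in = torch.randn(1024, device="cuda")
    static_out = torch.empty_like(static_in)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2.0)   # all buffers preexist the capture

    static_in.fill_(3.0)
    g.replay()                              # reruns the captured work in place
    assert torch.allclose(static_out, torch.full_like(static_out, 6.0))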
@@ -600,7 +630,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     elapsed_time = round((end_time_ - start_time_) * 1000, 7)  # compute elapsed time
     print("_time:{} us\n".format(elapsed_time))
     quantiles = [0.5, 0.2, 0.8]
-    gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
+    gpu_costtime = triton.testing.do_bench(lambda: w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size, output_dtype=out_dtype, best_config=best_config), quantiles=None, return_mode="mean") * 1000
     if bias is not None:
         output = output + bias
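`triton.testing.do_bench` warms up, synchronizes, and returns the runtime of the wrapped callable in milliseconds; the trailing `* 1000` converts to microseconds to match the `_time:{} us` print. A standalone sketch of the same measurement pattern (the matmul is a stand-in workload):

    import torch
    import triton

    x = torch.randn(4096, 4096, device="cuda")
    w = torch.randn(4096, 4096, device="cuda")

    ms = triton.testing.do_bench(lambda: x @ w, return_mode="mean")
    print(f"matmul: {ms * 1000:.1f} us")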