Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
d2f59cfa
"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "16dcfebbb1266e423b045de21d0c5ff1dc746dc3"
Commit
d2f59cfa
authored
Apr 03, 2025
by
Chunan Zeng
Committed by
LeiWang1999
Apr 03, 2025
Browse files
Support block_N sizes that are 2^n in deepgemm example (#319)
parent
853898a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
+9
-8
No files found.
examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
View file @
d2f59cfa
...
@@ -13,6 +13,7 @@ def tl_gemm(
     M,
     N,
     K,
+    block_N,
     in_dtype,
     out_dtype,
     accum_dtype,
...
@@ -25,15 +26,14 @@ def tl_gemm(
        "float32",
    ], "Currently only float16 and float32 are supported"
-    TILE_SIZE = (128, 128, 128)
-    block_M = TILE_SIZE[0]
-    block_N = TILE_SIZE[1]
-    block_K = TILE_SIZE[2]
+    group_size = 128
+    block_M = 128
+    block_K = 128
     A_shape = (M, K)
-    Scales_A_shape = (M, T.ceildiv(K, block_K))
+    Scales_A_shape = (M, T.ceildiv(K, group_size))
     B_shape = (N, K)
-    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
+    Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
     A_shared_shape = (block_M, block_K)
     B_shared_shape = (block_N, block_K)
     C_shared_shape = (block_M, block_N)
...
@@ -67,7 +67,7 @@ def tl_gemm(
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
-                Scale_B = scales_b[bx, k]
+                Scale_B = scales_b[bx * block_N // group_size, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
...
@@ -181,4 +181,5 @@ def assert_tl_gemm_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 if __name__ == "__main__":
     for dtype in ["e4m3_float8"]:
         for out_dtype in ["bfloat16", "float32"]:
-            assert_tl_gemm_correctness(1024, 1024, 8192, dtype, out_dtype, "float32")
+            for block_N in [16, 32, 64, 128]:
+                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment