Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
e2bc1cb6
Commit
e2bc1cb6
authored
Mar 14, 2025
by
Yuxuan Hu
Committed by
LeiWang1999
Mar 14, 2025
Browse files
[Bugfix] Fix `K // block_K` to T.ceildiv(K,block_K) and add tests (#210)
parent
227ed7ec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
+8
-3
No files found.
testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
View file @
e2bc1cb6
...
@@ -119,7 +119,7 @@ def tl_matmul(
...
@@ -119,7 +119,7 @@ def tl_matmul(
         T.clear(C_local)
-        for ko in T.Pipelined((K // block_K), num_stages=stage):
+        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
             # Load A into shared memory
             for i, k in T.Parallel(block_M, block_K):
...
@@ -182,7 +182,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
...
@@ -182,7 +182,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
     mod(compressed_A, compressed_B, C)
     print(C)
-    latency = mod.do_bench(mod.func, warmup=25, profiler="tvm")
+    latency = mod.do_bench(mod.func, warmup=25)
     print(latency)
     # Ensure that the latency is not None
     assert latency is not None
...
@@ -194,6 +194,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
...
@@ -194,6 +194,11 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+
+
+def test_assert_tl_matmul_correctness():
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 64, "int8", "int32", "int32")
+
 @simplify_prim_func
 def tl_matmul_weight_only_transform(
     M,
...
@@ -302,7 +307,7 @@ def tl_matmul_weight_only_transform(
...
@@ -302,7 +307,7 @@ def tl_matmul_weight_only_transform(
         T.clear(C_local)
-        for ko in T.Pipelined((K // block_K), num_stages=stage):
+        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
             # Load A into shared memory
             for i, k in T.Parallel(block_M, block_K):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment