OpenDAS / tilelang · Commits · 514bdeaa

Unverified commit 514bdeaa, authored Oct 22, 2025 by Lei Wang, committed via GitHub on Oct 22, 2025.

[Example] Add block level high performance gemv example (#1097)

* add alloc_reducer gemv example
* test

Parent: f003f371
Showing 2 changed files with 132 additions and 78 deletions:

    examples/gemv/example_gemv.py       +131  -77
    examples/gemv/test_example_gemv.py    +1   -1
examples/gemv/example_gemv.py
...
...
@@ -216,75 +216,122 @@ def splitk_gemv_vectorized_tvm(
    return main


Removed (the old thread-template autotuner wrapper; its inner kernel is unchanged and reappears below inside get_autotuned_kernel):

def get_best_config(N, K):

    def get_configs():
        iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
        return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

    @autotune(
        configs=get_configs(),
        warmup=3,
        rep=20,
    )
    @jit(
        out_idx=[-1],
        target="auto",
    )
    def kernel(
        BLOCK_N=None,
        reduce_threads=None,
    ):
        ...

    return kernel()
Added (the new block-template configs and the block-level reducer kernel):

def get_block_template_configs():
    iter_params = dict(
        block_M=[2, 4, 8, 32, 64, 128],
        block_N=[2, 4, 8, 32, 64, 128],
        num_stages=[0, 1, 2, 3, 4],
        threads=[32, 64, 128, 256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
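The config helpers in this file all use the same idiom: dict(zip(iter_params, values)) over itertools.product(*iter_params.values()) expands a dict of per-parameter candidate lists into one keyword dict per point of the Cartesian grid (6 * 6 * 5 * 4 = 720 configs here). A minimal standalone illustration, with made-up parameter values:

import itertools

iter_params = dict(block_M=[2, 4], threads=[32, 64])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
# configs == [{'block_M': 2, 'threads': 32}, {'block_M': 2, 'threads': 64},
#             {'block_M': 4, 'threads': 32}, {'block_M': 4, 'threads': 64}]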

@tl.autotune(
    configs=get_block_template_configs(),
    warmup=3,
    rep=20,
)
@tl.jit(
    pass_configs={
        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
    out_idx=[2],
)
def gemv_alloc_reducer(M,
                       N,
                       block_M=128,
                       block_N=128,
                       num_stages=2,
                       threads=256,
                       dtype: str = "float16",
                       accum_dtype: str = "float"):
    @T.prim_func
    def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)):  # type: ignore
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
            # per-row partial sums, kept in accumulation precision
            o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
            T.clear(o_reducer)
            for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
                # stage the A tile: global -> shared -> register fragment
                a_smem = T.alloc_shared((block_M, block_N), dtype)
                T.copy(a[i0_m * block_M, i0_n * block_N], a_smem)
                a_frag = T.alloc_fragment((block_M, block_N), dtype)
                T.copy(a_smem, a_frag)
                x_frag = T.alloc_fragment(block_N, dtype)
                T.copy(x[i0_n * block_N], x_frag)
                # tile-level multiply-accumulate: o[m] += A[m, n] * x[n]
                for i1_m, i1_n in T.Parallel(block_M, block_N):
                    o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n]
            # combine the replicated partials, then write this block's rows out
            T.finalize_reducer(o_reducer)
            T.copy(o_reducer, o[i0_m * block_M])

    return main
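For orientation, this is what gemv_alloc_reducer computes, restated as a plain NumPy sketch of the same tiling (a hand-written reference for this writeup, not part of the commit; names mirror the kernel above):

import numpy as np

def gemv_reference(a, x, block_M=128, block_N=128):
    M, N = a.shape
    o = np.zeros(M, dtype=np.float32)                # accumulate in float32, like accum_dtype
    for m0 in range(0, M, block_M):                  # one kernel block per block_M rows
        acc = np.zeros(a[m0:m0 + block_M].shape[0], dtype=np.float32)  # the reducer
        for n0 in range(0, N, block_N):              # the Pipelined loop over column tiles
            a_tile = a[m0:m0 + block_M, n0:n0 + block_N].astype(np.float32)
            x_tile = x[n0:n0 + block_N].astype(np.float32)
            acc += a_tile @ x_tile                   # the Parallel multiply-accumulate
        o[m0:m0 + block_M] = acc                     # finalize_reducer + copy back
    return o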

def get_thread_template_configs():
    iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

@autotune(
    configs=get_thread_template_configs(),
    warmup=3,
    rep=20,
)
@jit(
    out_idx=[-1],
    target="auto",
)
def get_autotuned_kernel(
        N,
        K,
        BLOCK_N=None,
        reduce_threads=None,
):
    # Note: the diff view lists this setup and the prim_func below twice (once as
    # removed lines from the old `kernel` wrapper, once as added lines here);
    # only the wrapper changed, so the code is reconstructed once.
    dtype = "float16"
    accum_dtype = "float"
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K
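    # Worked numbers (added note, not in the diff): "float16" is 16 bits wide, so
    # TILE_K = 128 // 16 = 8 elements per 128-bit vectorized access; with
    # reduce_threads = 32, one k-step covers BLOCK_K = 32 * 8 = 256 elements.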

    @T.prim_func
    def main(
            A: T.Tensor((K,), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)  # which output element within the block
            tk = T.get_thread_binding(1)  # which slice of the K reduction
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)
            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                # 128-bit vectorized loads of matching A and B slices
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                # per-thread partial dot product in accumulation precision
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            C_reduced = T.alloc_local((1,), accum_dtype)
            # sum the per-thread partials across the tk axis
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        C_accum[0],
                        True,
                        C_reduced[0],
                        tk,
                        dtype="handle",
                    ))
            C[bn * BLOCK_N + tn] = C_reduced[0]

    return main
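The tvm_thread_allreduce intrinsic sums the per-thread partials C_accum over the tk thread axis, so each thread ends up holding the complete dot product in C_reduced. A rough NumPy model of this split-K scheme for a single output element (a hypothetical helper for intuition, not part of the commit):

import numpy as np

def splitk_row(A, B_row, reduce_threads=32, TILE_K=8):
    BLOCK_K = reduce_threads * TILE_K
    partials = np.zeros(reduce_threads, dtype=np.float32)   # one C_accum per tk
    for bk in range(0, A.shape[0], BLOCK_K):
        for tk in range(reduce_threads):                    # these run concurrently on the GPU
            s = bk + tk * TILE_K
            partials[tk] += A[s:s + TILE_K].astype(np.float32) @ B_row[s:s + TILE_K].astype(np.float32)
    return partials.sum()                                   # the allreduce across tk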

def check_correctness_and_bench(kernel, N, K, bench_ref=True):
...
...
@@ -297,7 +344,7 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
    print(f"TileLang Latency: {latency} ms\n")
Removed:

def main():

Added:

def main(do_bench: bool = True):
    parser = argparse.ArgumentParser(description="GEMV Example")
    parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
...
...
@@ -308,16 +355,23 @@ def main():
    check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
    check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
    check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
    check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)  # new in this commit
    print("Test passed!")
Removed:

    best_result = get_best_config(N, K)
    best_config = best_result.config
    kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
    profiler = kernel.get_profiler()
    latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
    print(f"Torch Latency: {latency} ms")
    latency = profiler.do_bench(kernel, warmup=500)
    print(f"TileLang Latency: {latency} ms\n")
Added:

    if not do_bench:
        best_result = get_autotuned_kernel(N, K)
        best_config = best_result.config
        kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
        profiler = kernel.get_profiler()
        latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)  # Torch reference: x @ y.T
        print(f"Torch Latency: {latency} ms")
        tilelang_thread_latency = profiler.do_bench(kernel, warmup=500)
        print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n")

        kernel = gemv_alloc_reducer(N, K)
        profiler = kernel.get_profiler()
        tilelang_tile_latency = profiler.do_bench(kernel, warmup=500)
        print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n")

if __name__ == "__main__":
...
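Given the argparse defaults above, the example can presumably be run directly as `python examples/gemv/example_gemv.py --n 1024 --k 1024`.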
examples/gemv/test_example_gemv.py
...
...
@@ -4,7 +4,7 @@ import example_gemv
def test_example_gemv():

Removed:
    example_gemv.main()

Added:
    example_gemv.main(do_bench=False)
if __name__ == "__main__":
...
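Since the test is a plain pytest-style function, it should presumably be runnable with `pytest examples/gemv/test_example_gemv.py` as well as through the file's own __main__ guard.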