Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
c538d8ab
Unverified
Commit
c538d8ab
authored
Sep 25, 2025
by
Lei Wang
Committed by
GitHub
Sep 25, 2025
Browse files
[Language] Support sequence comparisons (#872)
* Update submodule 'tvm' to latest commit 7a71ee34 * lint fix
parent
2d4b848f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
1 deletion
+53
-1
3rdparty/tvm
3rdparty/tvm
+1
-1
testing/python/language/test_tilelang_language_if_range.py
testing/python/language/test_tilelang_language_if_range.py
+52
-0
No files found.
tvm
@
7a71ee34
Compare
0524f760
...
7a71ee34
Subproject commit
0524f7601d77df47c56253c9a675a6807f737d79
Subproject commit
7a71ee3411e49c3e05b1f1a910cf7f73adc7a5b2
testing/python/language/test_tilelang_language_if_range.py
0 → 100644
View file @
c538d8ab
import
tilelang
import
tilelang.language
as
T
import
torch
import
tilelang.testing
@
tilelang
.
jit
(
out_idx
=
[
1
],)
def
tilelang_if_range
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
row_idx
=
by
*
block_M
+
i
col_idx
=
bx
*
block_N
+
j
# Test condition: ca < i < cb where ca=16, cb=96
if
16
<
row_idx
<
96
:
B
[
row_idx
,
col_idx
]
=
A
[
row_idx
,
col_idx
]
*
2.0
else
:
B
[
row_idx
,
col_idx
]
=
A
[
row_idx
,
col_idx
]
*
0.5
return
main
def
run_tilelang_if_range
(
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float16"
):
kernel
=
tilelang_if_range
(
M
,
N
,
block_M
,
block_N
,
dtype
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
# Reference computation
ref_b
=
torch
.
zeros_like
(
a
)
for
i
in
range
(
M
):
for
j
in
range
(
N
):
# ca < i < cb where ca=16, cb=96
if
16
<
i
<
96
:
ref_b
[
i
,
j
]
=
a
[
i
,
j
]
*
2.0
else
:
ref_b
[
i
,
j
]
=
a
[
i
,
j
]
*
0.5
torch
.
testing
.
assert_close
(
b
,
ref_b
,
rtol
=
1e-2
,
atol
=
1e-2
)
def
test_tilelang_if_range
():
run_tilelang_if_range
(
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
)
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment