Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
15479958
Unverified
Commit
15479958
authored
Sep 17, 2025
by
Lei Wang
Committed by
GitHub
Sep 17, 2025
Browse files
[DSL] Support python tenary if then else expression (#822)
* support python tenary if then else expression * lint fix
parent
907c3ff0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
1 deletion
+45
-1
3rdparty/tvm
3rdparty/tvm
+1
-1
testing/python/language/test_tilelang_language_ternary.py
testing/python/language/test_tilelang_language_ternary.py
+44
-0
No files found.
tvm
@
b56420b3
Compare
87b845fa
...
b56420b3
Subproject commit
87b845fa0e14c2029bbf5799fbbbb9d490db4f20
Subproject commit
b56420b34277b6e257b0426eb78ecec1f1fb45fb
testing/python/language/test_tilelang_language_ternary.py
0 → 100644
View file @
15479958
import
tilelang
import
tilelang.language
as
T
import
torch
import
tilelang.testing
@
tilelang
.
jit
(
out_idx
=
[
1
],)
def
tilelang_ternary
(
M
,
N
,
block_M
,
block_N
,
dtype
=
"float16"
):
@
T
.
prim_func
def
main
(
A
:
T
.
Tensor
((
M
,
N
),
dtype
),
B
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
# Initialize Kernel Context
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
for
i
,
j
in
T
.
Parallel
(
block_M
,
block_N
):
B
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
=
(
A
[
by
*
block_M
+
i
,
bx
*
block_N
+
j
]
if
(
by
*
block_M
+
i
)
<
(
M
//
2
)
else
0
)
return
main
def
run_tilelang_ternary
(
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
,
dtype
=
"float16"
):
kernel
=
tilelang_ternary
(
M
,
N
,
block_M
,
block_N
,
dtype
)
a
=
torch
.
randn
(
M
,
N
,
device
=
"cuda"
,
dtype
=
getattr
(
torch
,
dtype
))
b
=
kernel
(
a
)
ref_b
=
torch
.
zeros_like
(
b
)
for
i
in
range
(
M
):
for
j
in
range
(
N
):
if
i
<
M
//
2
:
ref_b
[
i
,
j
]
=
a
[
i
,
j
]
else
:
ref_b
[
i
,
j
]
=
0
torch
.
testing
.
assert_close
(
b
,
ref_b
,
rtol
=
1e-2
,
atol
=
1e-2
)
def
test_tilelang_ternary
():
run_tilelang_ternary
(
M
=
128
,
N
=
128
,
block_M
=
32
,
block_N
=
32
)
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment