Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
94c941fc
Commit
94c941fc
authored
Mar 12, 2025
by
_HYX_
Committed by
GitHub
Mar 12, 2025
Browse files
[Language] Support clamp in language (#192)
* [Dev] Support clamp in language. * [Bugfix]: Fix clamp * [Refactor]
parent
efb2b1d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
0 deletions
+123
-0
testing/python/language/test_tilelang_language_clamp.py
testing/python/language/test_tilelang_language_clamp.py
+115
-0
tilelang/language/__init__.py
tilelang/language/__init__.py
+1
-0
tilelang/language/customize.py
tilelang/language/customize.py
+7
-0
No files found.
testing/python/language/test_tilelang_language_clamp.py
0 → 100644
View file @
94c941fc
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from
tilelang
import
tvm
as
tvm
import
tilelang.testing
import
tilelang
as
tl
def
clamp
(
N
,
block_N
,
dtype
,
min_val
=
None
,
max_val
=
None
,
):
import
tilelang.language
as
T
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
((
N
,),
dtype
),
B
:
T
.
Buffer
((
N
,),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
A_shared
=
T
.
alloc_shared
([
block_N
],
dtype
)
T
.
copy
(
A
[
bx
*
block_N
],
A_shared
)
for
i
in
T
.
Parallel
(
block_N
):
A_shared
[
i
]
=
T
.
clamp
(
A_shared
[
i
],
min_val
=
min_val
,
max_val
=
max_val
)
T
.
copy
(
A_shared
,
B
[
bx
*
block_N
])
return
main
def
run_clamp
(
N
,
block_N
,
dtype
,
min
=
None
,
max
=
None
,
):
program
=
clamp
(
N
,
block_N
,
dtype
,
min
,
max
)
mod
,
params
=
tl
.
lower
(
program
)
profiler
=
tl
.
Profiler
(
mod
,
params
,
[
1
],
tl
.
TensorSupplyType
.
Integer
)
def
ref_program
(
A
):
import
torch
output
=
torch
.
clamp
(
A
,
min
,
max
)
return
output
profiler
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
clamp_v2
(
N
,
block_N
,
dtype
,
):
import
tilelang.language
as
T
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
((
1
,
N
),
dtype
),
B
:
T
.
Buffer
((
1
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
threads
=
block_N
)
as
bx
:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
A_frag
=
T
.
alloc_fragment
([
1
,
block_N
],
dtype
=
dtype
)
min_frag
=
T
.
alloc_fragment
([
1
],
dtype
=
"float32"
)
max_frag
=
T
.
alloc_fragment
([
1
],
dtype
=
"float32"
)
T
.
copy
(
A
[
0
,
bx
*
block_N
],
A_frag
)
T
.
reduce_min
(
A_frag
,
min_frag
,
dim
=
1
)
T
.
reduce_max
(
A_frag
,
max_frag
,
dim
=
1
)
for
i
in
T
.
Parallel
(
block_N
):
# A_frag[0, i] = T.max(A_frag[0, i], min_frag[0] * 0.5)
# A_frag[0, i] = T.min(A_frag[0, i], max_frag[0] * 0.5)
A_frag
[
0
,
i
]
=
T
.
clamp
(
A_frag
[
0
,
i
],
min_frag
[
0
]
*
0.5
,
max_frag
[
0
]
*
0.5
)
T
.
copy
(
A_frag
,
B
[
0
,
bx
*
block_N
])
return
main
def
run_clamp_v2
(
N
,
block_N
,
dtype
,
):
program
=
clamp_v2
(
N
,
block_N
,
dtype
,
)
mod
,
params
=
tl
.
lower
(
program
)
profiler
=
tl
.
Profiler
(
mod
,
params
,
[
1
],
tl
.
TensorSupplyType
.
Integer
)
def
ref_program
(
A
):
import
torch
min_val
=
torch
.
min
(
A
)
*
0.5
max_val
=
torch
.
max
(
A
)
*
0.5
output
=
torch
.
clamp
(
A
,
min_val
,
max_val
)
return
output
profiler
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_clamp
():
# clamp tests for float16 and float32
run_clamp
(
1024
,
128
,
"float16"
,
-
0.05
,
0.05
)
run_clamp
(
1024
,
128
,
"float32"
,
-
0.06
,
0.05
)
run_clamp_v2
(
1024
,
128
,
"float16"
)
run_clamp_v2
(
1024
,
128
,
"float32"
)
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
tilelang/language/__init__.py
View file @
94c941fc
...
@@ -31,6 +31,7 @@ from .customize import (
...
@@ -31,6 +31,7 @@ from .customize import (
atomic_add
,
# noqa: F401
atomic_add
,
# noqa: F401
atomic_addx2
,
# noqa: F401
atomic_addx2
,
# noqa: F401
dp4a
,
# noqa: F401
dp4a
,
# noqa: F401
clamp
,
# noqa: F401
)
)
from
.builtin
import
*
# noqa: F401
from
.builtin
import
*
# noqa: F401
...
...
tilelang/language/customize.py
View file @
94c941fc
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
tvm.script
import
tir
as
T
from
tvm.script
import
tir
as
T
from
tvm.tir
import
PrimExpr
def
atomic_add
(
dst
,
value
):
def
atomic_add
(
dst
,
value
):
...
@@ -15,3 +16,9 @@ def atomic_addx2(dst, value):
...
@@ -15,3 +16,9 @@ def atomic_addx2(dst, value):
def
dp4a
(
A
,
B
,
C
):
def
dp4a
(
A
,
B
,
C
):
return
T
.
call_extern
(
"handle"
,
"DP4A"
,
T
.
address_of
(
A
),
T
.
address_of
(
B
),
T
.
address_of
(
C
))
return
T
.
call_extern
(
"handle"
,
"DP4A"
,
T
.
address_of
(
A
),
T
.
address_of
(
B
),
T
.
address_of
(
C
))
def
clamp
(
dst
,
min_val
:
PrimExpr
,
max_val
:
PrimExpr
):
dst
=
T
.
max
(
dst
,
min_val
)
dst
=
T
.
min
(
dst
,
max_val
)
return
dst
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment