Commit 94c941fc authored by _HYX_, committed by GitHub

[Language] Support clamp in language (#192)

* [Dev] Support clamp in language.

* [Bugfix]: Fix clamp

* [Refactor]
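For context, `T.clamp(x, min_val, max_val)` bounds `x` elementwise to `[min_val, max_val]`. A minimal usage sketch inside a TileLang kernel, mirroring the test below (constants illustrative, not part of the commit):

import tilelang.language as T

@T.prim_func
def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
    with T.Kernel(T.ceildiv(1024, 128), threads=128) as bx:
        A_shared = T.alloc_shared([128], "float32")
        T.copy(A[bx * 128], A_shared)
        for i in T.Parallel(128):
            A_shared[i] = T.clamp(A_shared[i], -0.05, 0.05)  # bound into [-0.05, 0.05]
        T.copy(A_shared, B[bx * 128])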
parent efb2b1d5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl


def clamp(
    N,
    block_N,
    dtype,
    min_val=None,
    max_val=None,
):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared([block_N], dtype)
            T.copy(A[bx * block_N], A_shared)
            for i in T.Parallel(block_N):
                # Clamp each element into [min_val, max_val].
                A_shared[i] = T.clamp(A_shared[i], min_val=min_val, max_val=max_val)
            T.copy(A_shared, B[bx * block_N])

    return main


def run_clamp(
    N,
    block_N,
    dtype,
    min_val=None,
    max_val=None,
):
    program = clamp(N, block_N, dtype, min_val, max_val)
    mod, params = tl.lower(program)
    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)

    def ref_program(A):
        import torch
        return torch.clamp(A, min_val, max_val)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def clamp_v2(
    N,
    block_N,
    dtype,
):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((1, N), dtype),
            B: T.Buffer((1, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_frag = T.alloc_fragment([1, block_N], dtype=dtype)
            min_frag = T.alloc_fragment([1], dtype="float32")
            max_frag = T.alloc_fragment([1], dtype="float32")
            T.copy(A[0, bx * block_N], A_frag)
            # Derive the bounds from the tile itself, so the clamp limits are
            # computed expressions rather than compile-time constants.
            T.reduce_min(A_frag, min_frag, dim=1)
            T.reduce_max(A_frag, max_frag, dim=1)
            for i in T.Parallel(block_N):
                # Equivalent to T.max with the lower bound, then T.min with the upper.
                A_frag[0, i] = T.clamp(A_frag[0, i], min_frag[0] * 0.5, max_frag[0] * 0.5)
            T.copy(A_frag, B[0, bx * block_N])

    return main


def run_clamp_v2(
    N,
    block_N,
    dtype,
):
    program = clamp_v2(
        N,
        block_N,
        dtype,
    )
    mod, params = tl.lower(program)
    profiler = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)

    def ref_program(A):
        import torch
        min_val = torch.min(A) * 0.5
        max_val = torch.max(A) * 0.5
        return torch.clamp(A, min_val, max_val)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_clamp():
    # Clamp with constant bounds, for float16 and float32.
    run_clamp(1024, 128, "float16", -0.05, 0.05)
    run_clamp(1024, 128, "float32", -0.06, 0.05)
    # Clamp with bounds computed from the input tile.
    run_clamp_v2(1024, 128, "float16")
    run_clamp_v2(1024, 128, "float32")


if __name__ == "__main__":
    tilelang.testing.main()
@@ -31,6 +31,7 @@ from .customize import (
     atomic_add,  # noqa: F401
     atomic_addx2,  # noqa: F401
     dp4a,  # noqa: F401
+    clamp,  # noqa: F401
 )
 from .builtin import *  # noqa: F401
@@ -3,6 +3,7 @@
 """The language interface for tl programs."""
 from tvm.script import tir as T
+from tvm.tir import PrimExpr


 def atomic_add(dst, value):
@@ -15,3 +16,9 @@ def atomic_addx2(dst, value):

 def dp4a(A, B, C):
     return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
+
+
+def clamp(dst, min_val: PrimExpr, max_val: PrimExpr):
+    dst = T.max(dst, min_val)
+    dst = T.min(dst, max_val)
+    return dst
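For reference, this max-then-min composition matches PyTorch's clamp semantics, which is exactly what the tests above compare against. A minimal standalone sketch of the equivalence, assuming PyTorch is available (values illustrative):

import torch

x = torch.tensor([-0.2, -0.05, 0.0, 0.03, 0.4])
lo, hi = -0.05, 0.05
# clamp(x, lo, hi) == min(max(x, lo), hi), applied elementwise
composed = torch.minimum(torch.maximum(x, torch.tensor(lo)), torch.tensor(hi))
assert torch.equal(composed, torch.clamp(x, lo, hi))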