"model/models/vscode:/vscode.git/clone" did not exist on "60829f7ec6ba12f8b06aa917bdba26c82f054e1f"
Commit 7ae35298 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Rename clamp functions and enhance dtype handling in tests (#232)

- Renamed `clamp` to `clamp_within_bounds` and `clamp_v2` to `clamp_value_range` for improved clarity.
- Updated dtype handling in `clamp_value_range` to use the correct tensor dtype.
- Modified test cases to reflect the new function names and ensure proper dtype conversion using `map_torch_type`.
- Enhanced the reference program for clamping to utilize the updated tensor dtype, improving accuracy in tests.
parent fe5f7b3b
import tilelang.testing import tilelang.testing
from tilelang.utils.tensor import map_torch_type
def clamp_within_bounds(
def clamp(
N, N,
block_N, block_N,
dtype, dtype,
...@@ -32,7 +32,7 @@ def run_clamp( ...@@ -32,7 +32,7 @@ def run_clamp(
min=None, min=None,
max=None, max=None,
): ):
program = clamp(N, block_N, dtype, min, max) program = clamp_within_bounds(N, block_N, dtype, min, max)
kernel = tilelang.compile(program, out_idx=[1]) kernel = tilelang.compile(program, out_idx=[1])
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -46,7 +46,7 @@ def run_clamp( ...@@ -46,7 +46,7 @@ def run_clamp(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def clamp_v2( def clamp_value_range(
N, N,
block_N, block_N,
dtype, dtype,
...@@ -61,8 +61,8 @@ def clamp_v2( ...@@ -61,8 +61,8 @@ def clamp_v2(
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype) # A_shared = T.alloc_shared([1, block_N], dtype=dtype)
A_frag = T.alloc_fragment([1, block_N], dtype=dtype) A_frag = T.alloc_fragment([1, block_N], dtype=dtype)
min_frag = T.alloc_fragment([1], dtype="float32") min_frag = T.alloc_fragment([1], dtype=dtype)
max_frag = T.alloc_fragment([1], dtype="float32") max_frag = T.alloc_fragment([1], dtype=dtype)
T.copy(A[0, bx * block_N], A_frag) T.copy(A[0, bx * block_N], A_frag)
T.reduce_min(A_frag, min_frag, dim=1) T.reduce_min(A_frag, min_frag, dim=1)
T.reduce_max(A_frag, max_frag, dim=1) T.reduce_max(A_frag, max_frag, dim=1)
...@@ -75,35 +75,41 @@ def clamp_v2( ...@@ -75,35 +75,41 @@ def clamp_v2(
return main return main
def run_clamp_value_range(
    N,
    block_N,
    dtype,
):
    """Compile the clamp_value_range kernel and validate it against torch.

    Args:
        N: number of elements in the (1, N) input row.
        block_N: tile width along N handled per kernel block.
        dtype: dtype name as a string (e.g. "float16"); translated to a
            torch.dtype via map_torch_type.

    Raises:
        AssertionError: if the kernel output diverges from the reference.
    """
    kernel = tilelang.compile(
        clamp_value_range(N, block_N, dtype),
        out_idx=[1],
    )

    import torch

    # Convert string dtype to torch.dtype
    torch_dtype = map_torch_type(dtype)

    def ref_program(A):
        # Reference: clamp each element into [0.5 * min(A), 0.5 * max(A)],
        # the data-dependent range the kernel is expected to produce.
        lo = torch.min(A) * 0.5
        hi = torch.max(A) * 0.5
        return torch.clamp(A, lo, hi)

    # Small-integer random input, moved to the GPU and cast to the target dtype.
    A = torch.randint(-5, 5, (1, N)).cuda().to(dtype=torch_dtype)
    out = kernel(A)
    expected = ref_program(A)
    torch.testing.assert_close(out, expected)
def test_clamp():
    """Run the clamp kernels for float16 and float32 inputs."""
    # Fixed-bounds clamp with explicit min/max per dtype.
    for dtype, lo, hi in (("float16", -0.05, 0.05), ("float32", -0.06, 0.05)):
        run_clamp(1024, 128, dtype, lo, hi)
    # Data-dependent clamp range derived from the input's min/max.
    for dtype in ("float16", "float32"):
        run_clamp_value_range(1024, 128, dtype)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment