"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "537d37c26c619a9aea48bcef44c1d8f45d6d7b1a"
Commit 872f5613 authored by Lei Wang, committed by LeiWang1999

[Language] Introduce `T.reshape` and `T.view` (#212)

* [Feature] Add reshape and view functionalities to tilelang language module

- Introduced new test files for reshape and view operations in the tilelang language.
- Implemented reshape and view functions in the customize module, enhancing buffer manipulation capabilities.
- Updated the language initialization to include the new functionalities.
- Removed unnecessary import from test_tilelang_language_clamp.py for cleaner code.

* Update copyright to Tile-AI Corporation

* [Refactor] Clean up whitespace in test files for reshape and view functionalities

- Removed unnecessary blank lines in `test_tilelang_language_reshape.py` and `test_tilelang_language_view.py` for improved readability.
- Ensured consistent formatting across test files to enhance code clarity.
parent 86f96f8a
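For orientation, here is a minimal sketch of how the two new primitives are used, distilled from the tests in this commit (the concrete 1024 -> 32x32 shape is illustrative, not part of the commit):

import tilelang as tl
import tilelang.language as T

@T.prim_func
def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((32, 32), "float32")):
    with T.Kernel(1) as _:
        A_2d = T.reshape(A, [32, 32])  # reinterpret the 1-D buffer as 32x32
        T.copy(A_2d, B)

T.view additionally accepts a dtype, e.g. T.view(A, [32, 16], dtype="float32") on a (1024,) float16 buffer reinterprets the same 2048 bytes under the new element type.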
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
def reshape_test(N, M, dtype):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_reshaped = T.reshape(A, [N // M, M])
            T.copy(A_reshaped, B)

    return main


def run_reshape(N, M, dtype):
    program = reshape_test(N, M, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape():
    # Test reshape
    run_reshape(1024, 32, "float32")
    run_reshape(2048, 64, "float16")
def reshape_test_smem(N, M, dtype):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer((N // M, M), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((N,), dtype)
            for i in range(N):
                A_shared[i] = A[i]
            A_smem_reshaped = T.reshape(A_shared, [N // M, M])
            for i in range(N // M):
                for j in range(M):
                    B[i, j] = A_smem_reshaped[i, j]

    return main


def run_reshape_smem(N, M, dtype):
    program = reshape_test_smem(N, M, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.reshape(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_reshape_smem():
    run_reshape_smem(1024, 32, "float32")
    run_reshape_smem(2048, 64, "float16")
if __name__ == "__main__":
    tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
def view_test(N, M, dtype, new_dtype=None):
    import tilelang.language as T

    new_shape = [N // M, M]
    if new_dtype:
        from tvm import DataType
        dtype_src = DataType(dtype)
        dtype_dst = DataType(new_dtype)
        src_bits = dtype_src.bits
        dst_bits = dtype_dst.bits
        scale = src_bits / dst_bits
        new_shape[-1] = int(M * scale)

    @T.prim_func
    def main(
            A: T.Buffer((N,), dtype),
            B: T.Buffer(new_shape, new_dtype if new_dtype else dtype),
    ):
        with T.Kernel(1) as _:
            A_viewed = T.view(A, new_shape, dtype=new_dtype)
            T.copy(A_viewed, B)

    return main
def run_view(N, M, dtype, new_dtype=None):
    program = view_test(N, M, dtype, new_dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        if new_dtype:
            from tilelang.utils.tensor import map_torch_type
            torch_dtype = map_torch_type(new_dtype)
            return A.view(N // M, M).view(dtype=torch_dtype)
        return A.view(N // M, M)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_view():
    # Test view with same dtype
    run_view(1024, 32, "float32")
    run_view(2048, 64, "float16")
    # Test view with dtype conversion
    run_view(1024, 32, "float32", "float16")
    run_view(2048, 64, "float16", "float32")
if __name__ == "__main__":
    tilelang.testing.main()
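A note on the shape arithmetic above: when `new_dtype` differs from `dtype`, `view_test` rescales the last dimension by the ratio of element bit widths so the total byte count is unchanged. A small standalone restatement of that calculation (the helper name is illustrative, not from the commit):

from tvm import DataType

def scaled_last_dim(M, dtype, new_dtype):
    # Keep total bytes constant: scale by src_bits / dst_bits, as view_test does.
    scale = DataType(dtype).bits / DataType(new_dtype).bits
    return int(M * scale)

assert scaled_last_dim(32, "float32", "float16") == 64  # fp32 -> fp16 doubles M
assert scaled_last_dim(64, "float16", "float32") == 32  # fp16 -> fp32 halves M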
@@ -75,6 +75,8 @@ class TLCUDASourceWrapper(object):
     }

     backend = "tl"
+    device_mod: Optional[IRModule] = None
+    host_mod: Optional[IRModule] = None

     def __init__(self,
                  scheduled_ir_module: IRModule,
@@ -108,7 +110,7 @@ class TLCUDASourceWrapper(object):
             if param in self.prim_func.buffer_map:
                 buffer = self.prim_func.buffer_map[param]
                 function_args.append({
-                    "name": buffer.name,
+                    "name": buffer.data.name,
                     "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
                 })
             elif isinstance(param, tvm.tir.Var):
@@ -243,8 +245,11 @@ class TLCUDASourceWrapper(object):
     def parse_source_information(self):
         with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
             device_mod, host_mod = get_annotated_mod(self.mod, self.target)
         assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
         assert (len(host_mod.functions) == 1), "Only support one function in host module."
+        self.device_mod = device_mod
+        self.host_mod = host_mod

         block_info_map = {}
         grid_info_map = {}
@@ -32,6 +32,8 @@ from .customize import (
     atomic_addx2,  # noqa: F401
     dp4a,  # noqa: F401
     clamp,  # noqa: F401
+    reshape,  # noqa: F401
+    view,  # noqa: F401
 )
 from .builtin import *  # noqa: F401
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 """The language interface for tl programs."""
 from tvm.script import tir as T
-from tvm.tir import PrimExpr
+from tvm.tir import PrimExpr, Buffer
+from typing import List, Union


-def atomic_add(dst, value):
+def atomic_add(dst: Buffer, value: PrimExpr) -> PrimExpr:
     return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)


-def atomic_addx2(dst, value):
+def atomic_addx2(dst: Buffer, value: PrimExpr) -> PrimExpr:
     return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))


-def dp4a(A, B, C):
+def dp4a(A: Buffer, B: Buffer, C: Buffer) -> PrimExpr:
     return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))


-def clamp(dst, min_val: PrimExpr, max_val: PrimExpr):
-    dst = T.max(dst, min_val)
-    dst = T.min(dst, max_val)
+def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
+    """Clamps the input value dst between [min_val, max_val]
+
+    Args:
+        dst: Input value to be clamped
+        min_val: Minimum value
+        max_val: Maximum value
+
+    Returns:
+        Value clamped to the specified range
+    """
+    dst = T.max(dst, min_val)  # Ensure value is not less than minimum
+    dst = T.min(dst, max_val)  # Ensure value is not greater than maximum
     return dst
+
+
+def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
+    """Reshapes the input buffer to the specified shape.
+
+    Args:
+        src: Input buffer to be reshaped
+        shape: New shape for the buffer
+    """
+    return T.Buffer(shape, src.dtype, src.data)
+
+
+def view(src: Buffer,
+         shape: Union[List[PrimExpr], None] = None,
+         dtype: Union[str, None] = None) -> Buffer:
+    """Views the input buffer to the specified shape.
+
+    Args:
+        src: Input buffer to be viewed
+        shape: New shape for the buffer
+        dtype: New dtype for the buffer
+    """
+    if shape is None:
+        shape = src.shape
+    if dtype is None:
+        dtype = src.dtype
+    return T.Buffer(shape, dtype, src.data)
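Worth noting from the definitions above: neither helper moves data. Both wrap the existing allocation (`src.data`) in a fresh `T.Buffer`, so the result is a zero-cost reinterpretation, and writes through either handle touch the same memory. A hedged sketch of what that means in practice, modeled on the shared-memory test in this commit (the function name and the 64 -> 8x8 shape are illustrative):

import tilelang.language as T

@T.prim_func
def alias_sketch(A: T.Buffer((64,), "float32"), B: T.Buffer((8, 8), "float32")):
    with T.Kernel(1) as _:
        A_shared = T.alloc_shared((64,), "float32")
        for i in range(64):
            A_shared[i] = A[i]
        A_2d = T.reshape(A_shared, [8, 8])  # shares A_shared's storage; no copy
        for i in range(8):
            for j in range(8):
                B[i, j] = A_2d[i, j]  # reads the values written via A_shared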