Commit bf8a6fc1 authored by Lei Wang, committed by LeiWang1999

[Refactor] Deprecate `T.Buffer` as an argument annotation and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for the `shared.dyn` and `shared` memory scopes in the Fill operation (a minimal sketch follows this list).
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
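
A minimal sketch of the pattern this enables, filling a shared-memory allocation directly with `T.fill`. The shapes, dtype, and kernel launch below are illustrative assumptions, not code from this change:

```python
import tilelang.language as T

M, N = 128, 128
dtype = "float16"

@T.prim_func
def fill_shared(A: T.Tensor((M, N), dtype)):
    with T.Kernel(1, threads=128) as bx:
        # T.alloc_shared places the buffer in shared memory
        # (the shared.dyn scope that Fill now understands).
        A_shared = T.alloc_shared([M, N], dtype)
        T.fill(A_shared, 0)
        T.copy(A_shared, A)
```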

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization (a generic sketch of such a decorator follows this list).
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.
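
For context, a deprecation helper along these lines (a generic sketch, not the verbatim `tilelang` implementation) emits a `DeprecationWarning` and forwards to the wrapped callable:

```python
import functools
import warnings

def deprecated(reason: str):
    """Mark a callable as deprecated; warn when it is called and delegate to it."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator
```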

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types (see the sketch after this list).
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.
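
The kernel body follows the usual TileLang tile-level GEMM shape, sketched below with illustrative block sizes and dtypes; the actual `matmul_test` in the new test file may differ in details:

```python
import tilelang.language as T

def matmul_kernel(M, N, K, block_M=128, block_N=128, block_K=32,
                  dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Stage tiles of A and B through shared memory and accumulate in registers.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```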

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added a `make_tensor` function in `proxy.py` to streamline tensor creation from pointers (see the sketch below).
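
A hedged sketch of the resulting pattern; the pointer-argument annotation and the exact `make_tensor` signature are assumptions, not verbatim code from `test_tilelang_language_ptr.py`:

```python
import tilelang.language as T

M, N, block_M = 1024, 1024, 128
dtype = "float16"

@T.prim_func
def copy_from_ptrs(src_ptr: T.handle, dst_ptr: T.handle):
    # make_tensor wraps a raw pointer plus shape/dtype into a Tensor view
    # (assumed signature: make_tensor(ptr, shape, dtype)).
    Src = T.make_tensor(src_ptr, (M, N), dtype)
    Dst = T.make_tensor(dst_ptr, (M, N), dtype)
    with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
        tile = T.alloc_shared((block_M, N), dtype)
        T.copy(Src[bx * block_M, 0], tile)
        T.copy(tile, Dst[bx * block_M, 0])
```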

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions (see the illustration after this list).
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
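
As an illustration of the annotation style after this rename, mirroring the `Rescale` macro touched in the diff below (`T.SharedBuffer` is the analogous annotation for shared-memory operands); the concrete sizes here are assumptions:

```python
import tilelang.language as T

block_M, dim = 64, 128
accum_dtype = "float"

@T.macro
def rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
    # Scale each row of the register-resident accumulator.
    for i, j in T.Parallel(block_M, dim):
        acc_o[i, j] *= scores_scale[i]
```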

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs map to `torch.Tensor` (see the usage sketch below).
- This change enhances clarity and consistency in tensor usage across the codebase.
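
In practice this means a JIT-compiled kernel whose arguments are annotated with `T.Tensor` can be driven directly by `torch.Tensor` objects. A hedged usage sketch follows; the shapes, `out_idx` choice, and direct-call style are assumptions based on the examples this commit touches:

```python
import torch
import tilelang
import tilelang.language as T

M, N, block_M = 1024, 1024, 128
dtype = "float16"

@tilelang.jit(out_idx=[1])
@T.prim_func
def copy_kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
    with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
        tile = T.alloc_shared((block_M, N), dtype)
        T.copy(A[bx * block_M, 0], tile)
        T.copy(tile, B[bx * block_M, 0])

a = torch.randn(M, N, dtype=torch.float16, device="cuda")
b = copy_kernel(a)  # B is allocated and returned because of out_idx=[1]
torch.testing.assert_close(b, a)
```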
parent 73d2c62e
@@ -88,13 +88,13 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     def kernel_func(block_M, block_N, block_K, block_Dstate, num_stages, threads):
         @T.prim_func
-        def main(cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
-                x: T.Buffer((batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
-                (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
-                (batch, nheads, nchunks, chunk_size), dtype), C: T.Buffer(
-                (batch, seqlen, ngroups, dstate), dtype), prev_states: T.Buffer(
-                (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Buffer(
-                (nheads), dtype), Output: T.Buffer(
+        def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),
+                x: T.Tensor((batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
+                (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
+                (batch, nheads, nchunks, chunk_size), dtype), C: T.Tensor(
+                (batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
+                (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
+                (nheads), dtype), Output: T.Tensor(
                 (batch, seqlen, nheads, headdim), dtype)):
             with T.Kernel(
                     nheads,
...
@@ -71,10 +71,10 @@ def chunk_state_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
     def kernel_func(block_M, block_N, block_K, num_stages, threads):
         @T.prim_func
-        def main(B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), x: T.Buffer(
-                (batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
-                (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
-                (batch, nheads, nchunks, chunk_size), dtype), Output: T.Buffer(
+        def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
+                (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
+                (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
+                (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
                 (batch, nchunks, nheads, headdim, dstate), dtype)):
             with T.Kernel(
                     nheads,
...
@@ -12,11 +12,11 @@ def retnet(batch, heads, seq_len, dim_qk, dim_v, block_M, block_N):
     @T.prim_func
     def main(
-            Q: T.Buffer(qk_shape, dtype),
-            K: T.Buffer(qk_shape, dtype),
-            V: T.Buffer(v_shape, dtype),
-            mask: T.Buffer([heads, seq_len, seq_len], dtype),
-            Output: T.Buffer(v_shape, dtype),
+            Q: T.Tensor(qk_shape, dtype),
+            K: T.Tensor(qk_shape, dtype),
+            V: T.Tensor(v_shape, dtype),
+            mask: T.Tensor([heads, seq_len, seq_len], dtype),
+            Output: T.Tensor(v_shape, dtype),
     ):
         with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128 * 2) as (bx, by, bz):
             Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
@@ -118,13 +118,13 @@ def retnet_inference(batch, heads, dim_qk, dim_v, block_M):
     @T.prim_func
     def main(
-            Q: T.Buffer(qk_shape, dtype),
-            K: T.Buffer(qk_shape, dtype),
-            V: T.Buffer(v_shape, dtype),
-            prev_kv: T.Buffer([batch, heads, dim_v, dim_qk], dtype),
-            prev_scale: T.Buffer([heads], dtype),
-            decay: T.Buffer([heads], dtype),
-            Output: T.Buffer([batch, heads, dim_v], dtype),
+            Q: T.Tensor(qk_shape, dtype),
+            K: T.Tensor(qk_shape, dtype),
+            V: T.Tensor(v_shape, dtype),
+            prev_kv: T.Tensor([batch, heads, dim_v, dim_qk], dtype),
+            prev_scale: T.Tensor([heads], dtype),
+            decay: T.Tensor([heads], dtype),
+            Output: T.Tensor([batch, heads, dim_v], dtype),
     ):
         with T.Kernel(T.ceildiv(dim_v, block_M), heads, batch, threads=128) as (bx, by, bz):
             Q_local = T.alloc_fragment([1, dim_qk], dtype)
...
@@ -476,11 +476,11 @@ def tilelang_sparse_attention(batch,
     @T.prim_func
     def tilelang_sparse_attention(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
+            Output: T.Tensor(q_shape, dtype),
     ):
         with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([G, BK], dtype)
...
@@ -58,12 +58,12 @@ def tilelang_kernel_fwd(
     @tilelang.jit
     @T.prim_func
     def native_sparse_attention(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
-            O_slc: T.Buffer(o_slc_shape, dtype),
-            LSE_slc: T.Buffer(lse_slc_shape, accum_dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
+            O_slc: T.Tensor(o_slc_shape, dtype),
+            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
     ):
         with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([G, BK], dtype)
@@ -196,15 +196,15 @@ def tilelang_kernel_bwd_dkv(
     @tilelang.jit
     @T.prim_func
     def flash_bwd_dkv(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(k_shape, dtype),
-            V: T.Buffer(v_shape, dtype),
-            LSE_slc: T.Buffer(lse_slc_shape, accum_dtype),
-            Delta_slc: T.Buffer(delta_slc_shape, accum_dtype),
-            DO_slc: T.Buffer(do_slc_shape, dtype),
-            DK: T.Buffer(dk_shape, dtype),
-            DV: T.Buffer(dv_shape, dtype),
-            BlockMask: T.Buffer(block_mask_shape, "int32"),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(k_shape, dtype),
+            V: T.Tensor(v_shape, dtype),
+            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
+            Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
+            DO_slc: T.Tensor(do_slc_shape, dtype),
+            DK: T.Tensor(dk_shape, dtype),
+            DV: T.Tensor(dv_shape, dtype),
+            BlockMask: T.Tensor(block_mask_shape, "int32"),
     ):
         with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
             K_shared = T.alloc_shared([BS, BK], dtype)
@@ -360,16 +360,16 @@ def tilelang_kernel_bwd_dqkv(
     @tilelang.jit
     @T.prim_func
     def flash_bwd_dqkv(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(k_shape, dtype),
-            V: T.Buffer(v_shape, dtype),
-            LSE_slc: T.Buffer(lse_slc_shape, accum_dtype),
-            Delta_slc: T.Buffer(delta_slc_shape, accum_dtype),
-            DO_slc: T.Buffer(do_slc_shape, dtype),
-            DQ: T.Buffer(dq_shape, dtype),
-            DK: T.Buffer(dk_shape, dtype),
-            DV: T.Buffer(dv_shape, dtype),
-            BlockMask: T.Buffer(block_mask_shape, "int32"),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(k_shape, dtype),
+            V: T.Tensor(v_shape, dtype),
+            LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
+            Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
+            DO_slc: T.Tensor(do_slc_shape, dtype),
+            DQ: T.Tensor(dq_shape, dtype),
+            DK: T.Tensor(dk_shape, dtype),
+            DV: T.Tensor(dv_shape, dtype),
+            BlockMask: T.Tensor(block_mask_shape, "int32"),
     ):
         with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
             K_shared = T.alloc_shared([BS, BK], dtype)
@@ -489,9 +489,9 @@ def tilelang_kernel_preprocess(
     @tilelang.jit(out_idx=[2], execution_backend="cython")
     @T.prim_func
     def flash_bwd_prep(
-            O: T.Buffer(shape, dtype),  # type: ignore
-            dO: T.Buffer(shape, dtype),  # type: ignore
-            Delta: T.Buffer([batch, seq_len, heads], accum_dtype),  # type: ignore
+            O: T.Tensor(shape, dtype),  # type: ignore
+            dO: T.Tensor(shape, dtype),  # type: ignore
+            Delta: T.Tensor([batch, seq_len, heads], accum_dtype),  # type: ignore
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
             o = T.alloc_fragment([blk, blk], dtype)
@@ -532,9 +532,9 @@ def tilelang_kernel_block_mask(
     @tilelang.jit(out_idx=[2], execution_backend="cython")
     @T.prim_func
     def flash_bwd_block_mask(
-            BlockIndices: T.Buffer(block_indices_shape, dtype),  # type: ignore
-            BlockCounts: T.Buffer(block_counts_shape, dtype),  # type: ignore
-            BlockMask: T.Buffer(block_mask_shape, dtype),  # type: ignore
+            BlockIndices: T.Tensor(block_indices_shape, dtype),  # type: ignore
+            BlockCounts: T.Tensor(block_counts_shape, dtype),  # type: ignore
+            BlockMask: T.Tensor(block_mask_shape, dtype),  # type: ignore
     ):
         with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz):
             i_t, i_b, i_hs = bx, by, bz
...
@@ -44,12 +44,12 @@ def native_sparse_attention(
     @T.prim_func
     def native_sparse_attention(
-            Q: T.Buffer(q_shape, dtype),  # [batch, 1, heads, dim]
-            K: T.Buffer(kv_shape, dtype),  # [batch, seq_len, head_kv, dim]
-            V: T.Buffer(kv_shape, dtype),  # Same shape as K
-            BlockIndices: T.Buffer(block_indices_shape,
+            Q: T.Tensor(q_shape, dtype),  # [batch, 1, heads, dim]
+            K: T.Tensor(kv_shape, dtype),  # [batch, seq_len, head_kv, dim]
+            V: T.Tensor(kv_shape, dtype),  # Same shape as K
+            BlockIndices: T.Tensor(block_indices_shape,
                                    block_indices_dtype),  # Selected block indices
-            Output: T.Buffer(q_shape, dtype),  # Output attention tensor
+            Output: T.Tensor(q_shape, dtype),  # Output attention tensor
     ):
         with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
             # Shared memory allocations for tile storage
...
@@ -45,11 +45,11 @@ def native_sparse_attention(batch,
     @T.prim_func
     def native_sparse_attention(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
+            Output: T.Tensor(q_shape, dtype),
     ):
         with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([G, BK], dtype)
...
@@ -54,14 +54,14 @@ def native_sparse_attention_varlen(batch,
     @T.prim_func
     def native_sparse_attention_varlen(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            O_slc: T.Buffer(o_slc_shape, dtype),
-            BlockIndices: T.Buffer(block_indices_shape, block_indices_dtype),
-            BlockCounts: T.Buffer(block_counts_shape, block_counts_dtype),
-            Offsets: T.Buffer(offsets_shape, offsets_dtype),
-            TokenIndices: T.Buffer(token_indices_shape, token_indices_dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            O_slc: T.Tensor(o_slc_shape, dtype),
+            BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
+            BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype),
+            Offsets: T.Tensor(offsets_shape, offsets_dtype),
+            TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype),
     ):
         with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([G, BK], dtype)
...
@@ -7,7 +7,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
     dtype = "float"
     @T.prim_func
-    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
+    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
             A_shared = T.alloc_shared((blk_m, blk_k), dtype)
             A_local = T.alloc_fragment((blk_m, blk_k), dtype)
@@ -37,7 +37,7 @@ def rms_norm(M, N, blk_m):
     dtype = "float"
     @T.prim_func
-    def main(A: T.Buffer((M, N), dtype), B: T.Buffer((M, N), dtype)):
+    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
             A_shared = T.alloc_shared((blk_m, N), dtype)
             A_local = T.alloc_fragment((blk_m, N), dtype)
...
@@ -12,9 +12,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     # add decorator @tilelang.jit if you want to return a torch function
     @T.prim_func
     def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         # Initialize Kernel Context
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...
@@ -47,13 +47,13 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
     @T.macro
     def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
     ):
         T.copy(scores_max, scores_max_prev)
         T.fill(scores_max, -T.infinity(accum_dtype))
@@ -77,19 +77,19 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
     @T.macro
     def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
     ):
         for i, j in T.Parallel(block_M, dim):
             acc_o[i, j] *= scores_scale[i]
     @T.prim_func
     def main(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
+            Output: T.Tensor(q_shape, dtype),
     ):
         with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
             Q_shared = T.alloc_shared([block_M, dim], dtype)
...
@@ -217,9 +217,9 @@ def matmul(M, N, K, with_roller):
     @T.prim_func
     def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -82,9 +82,9 @@ def tl_matmul(
     @T.prim_func
     def main(
-            A: T.Buffer(A_shape, in_dtype),
-            B: T.Buffer(B_shape, in_dtype),
-            C: T.Buffer((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -27,7 +27,7 @@ def matmul(
     vec_size = 4 * k_pack
     @T.prim_func
-    def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
+    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
             (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -191,9 +191,9 @@ def matmul(M, N, K, with_roller):
     @T.prim_func
     def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -195,9 +195,9 @@ def matmul(M, N, K, with_roller):
     @T.prim_func
     def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     @T.prim_func
     def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -10,9 +10,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     @T.prim_func
     def matmul(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
             A_local = T.alloc_local((block_M, block_K), dtype)
@@ -66,9 +66,9 @@ def test_matmul_compile():
     # a simple kernel just for jit test
     @T.prim_func
     def matmul(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), is_cpu=True) as (bx, by):
             A_local = T.alloc_local((block_M, block_K), dtype)
...
@@ -8,7 +8,7 @@ import tilelang.language as T
 def debug_print_buffer(M=16, N=16, dtype="float16"):
     @T.prim_func
-    def program(Q: T.Buffer((M, N), dtype)):
+    def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             shared_buf = T.alloc_shared([M, N], dtype)
             T.print(shared_buf)
@@ -28,7 +28,7 @@ def debug_print_buffer_conditional(M=16, N=16):
     dtype = "float16"
     @T.prim_func
-    def program(Q: T.Buffer((M, N), dtype)):
+    def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             shared_buf = T.alloc_shared([M, N], dtype)
@@ -48,7 +48,7 @@ def debug_print_value_conditional(M=16, N=16):
     dtype = "float16"
     @T.prim_func
-    def program(Q: T.Buffer((M, N), dtype)):
+    def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             tid = T.get_thread_binding()
             if tid == 0:
@@ -67,7 +67,7 @@ def debug_print_register_files(M=16, N=16):
     dtype = "float16"
     @T.prim_func
-    def program(Q: T.Buffer((M, N), dtype)):
+    def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             register_buf = T.alloc_fragment([M, N], dtype)
             for i, j in T.Parallel(M, N):
@@ -86,7 +86,7 @@ def debug_print_msg(M=16, N=16):
     dtype = "float16"
     @T.prim_func
-    def program(Q: T.Buffer((M, N), dtype)):
+    def program(Q: T.Tensor((M, N), dtype)):
         with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
             tid = T.get_thread_binding()
             if tid == 0:
...
@@ -96,9 +96,9 @@ def tl_matmul_macro(
     @T.prim_func
     def main(
-            A: T.Buffer(A_shape, in_dtype),
-            B: T.Buffer(B_shape, in_dtype),
-            C: T.Buffer((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -207,7 +207,7 @@ def tl_matmul_block(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
     @T.prim_func
-    def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
+    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
            (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -306,7 +306,7 @@ def tl_matmul_block_all_dynamic(
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
     @T.prim_func
-    def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer(
+    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
            (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...