Commit bf8a6fc1 authored by Lei Wang, committed by LeiWang1999

[Refactor] Deprecate `T.Buffer` as an argument type and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.
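
A minimal, hypothetical sketch of the parallelized update (buffer names and the expression are illustrative, not the literal diff): writing the per-row split through `T.Parallel` lets the fragment layout be inferred rather than annotated by hand.

```python
import tilelang.language as T


def lse_split_sketch(batch, block_M, accum_dtype="float"):
    # Illustrative only: a per-row log-sum-exp split updated via T.Parallel,
    # so no explicit layout annotation is needed for lse_local_split.
    @T.prim_func
    def main(logsum: T.Tensor((batch, block_M), accum_dtype),
             lse_split: T.Tensor((batch, block_M), accum_dtype)):
        with T.Kernel(batch, threads=128) as bz:
            lse_local_split = T.alloc_fragment([block_M], accum_dtype)
            for i in T.Parallel(block_M):
                lse_local_split[i] = logsum[bz, i]
                lse_split[bz, i] = T.log(lse_local_split[i])

    return main
```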

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
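
A minimal sketch of what this enables, assuming the standard `T.alloc_shared`, `T.fill`, and `T.copy` helpers (the kernel is illustrative, not code from this change):

```python
import tilelang.language as T


def fill_shared_sketch(M, N, block_M=64, dtype="float16"):
    @T.prim_func
    def main(Out: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            # buf lives in shared memory; filling a shared-scope buffer now
            # lowers to a parallel, layout-inferred fill instead of being rejected.
            buf = T.alloc_shared((block_M, N), dtype)
            T.fill(buf, 1.0)
            T.copy(buf, Out[bx * block_M, 0])

    return main
```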

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the `deprecated` decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to use the new pointer map when handling kernel arguments.
- Tightened error checking in the tensor utility functions to enforce static shapes.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.
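
For reference, a utils-level deprecation helper generally looks like the sketch below; the name, module location, and message format are assumptions, not the exact implementation added in this commit.

```python
import functools
import warnings


def deprecated(reason: str):
    """Mark a callable as deprecated; emits a DeprecationWarning when it is called."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Point the warning at the caller rather than at this wrapper.
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```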

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added `make_tensor` function in `proxy.py` to streamline tensor creation from pointers.
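
A hedged sketch of the intended usage: `T.make_tensor` binds a typed tensor view over a pointer argument inside the kernel. The `T.handle` pointer annotation and the `T.make_tensor(ptr, shape, dtype)` signature are assumptions based on the description above.

```python
import tilelang.language as T


def matmul_from_ptr(M, N, K, block_M=128, block_N=128, block_K=32,
                    dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(a_ptr: T.handle, b_ptr: T.handle, c_ptr: T.handle):
        # Assumed signature: T.make_tensor(pointer, shape, dtype) -> tensor view.
        A = T.make_tensor(a_ptr, (M, K), dtype)
        B = T.make_tensor(b_ptr, (K, N), dtype)
        C = T.make_tensor(c_ptr, (M, N), dtype)
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```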

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions.
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
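
In practice the annotation mirrors the allocation scope of the value a macro receives; a minimal sketch modeled on the `Rescale` macro from the block-sparse attention example (shared-memory tiles are annotated with `T.SharedBuffer` in the same way):

```python
import tilelang.language as T


def rescale_sketch(block_M, dim, accum_dtype="float"):
    # Macro parameters carry their memory scope: register fragments are
    # annotated as T.FragmentBuffer, shared tiles as T.SharedBuffer.
    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    return Rescale
```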

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped to `torch.Tensor`.
- This change enhances clarity and consistency in tensor usage across the codebase.
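
A short usage sketch of the resulting boundary behavior: arguments declared as `T.Tensor` are exchanged as `torch.Tensor` at runtime. The `tilelang.compile(..., out_idx=[-1])` call is an assumed JIT entry point shown only to illustrate the mapping; the `@tilelang.jit` decorator serves the same purpose.

```python
import torch
import tilelang
import tilelang.language as T


def add_kernel(M, N, block_M=64, dtype="float16"):
    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            Out: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=128) as bx:
            a = T.alloc_fragment((block_M, N), dtype)
            b = T.alloc_fragment((block_M, N), dtype)
            T.copy(A[bx * block_M, 0], a)
            T.copy(B[bx * block_M, 0], b)
            for i, j in T.Parallel(block_M, N):
                a[i, j] += b[i, j]
            T.copy(a, Out[bx * block_M, 0])

    return main


# Hypothetical usage: the compile entry point and out_idx convention are assumptions.
kernel = tilelang.compile(add_kernel(1024, 1024), out_idx=[-1])
x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
y = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
z = kernel(x, y)  # inputs and the returned output are torch.Tensor objects
```
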
parent 73d2c62e
@@ -129,9 +129,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -51,10 +51,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def MMA0(
K: T.Buffer(shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
@@ -71,10 +71,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def MMA1(
V: T.Buffer(shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
@@ -84,13 +84,13 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
@@ -114,19 +114,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype),
Output: T.Buffer(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
@@ -217,9 +217,9 @@ def matmul(M, N, K, with_roller):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
@@ -103,9 +103,9 @@ def tl_matmul(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -27,7 +27,7 @@ Please note that this tutorial does not delve deeply into the design principles
def elementwise_add(N, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Buffer((N), dtype), B: T.Buffer((N), dtype), C: T.Buffer((N), dtype)):
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads), threads=threads) as (b_x):
# vector add.
for i in T.Parallel(threads):
@@ -67,9 +67,9 @@ def elementwise_add(
):
@T.prim_func
def main(
A: T.Buffer((M, N), in_dtype),
B: T.Buffer((M, N), in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor((M, N), in_dtype),
B: T.Tensor((M, N), in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
@@ -105,7 +105,7 @@ When compiling the example below, let's set `N` to 2047:
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Buffer((N), dtype), B: T.Buffer((N), dtype), C: T.Buffer((N), dtype)):
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
@@ -179,7 +179,7 @@ In such scenarios, explicitly specifying the number of elements computed per thr
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Buffer((N), dtype), B: T.Buffer((N), dtype), C: T.Buffer((N), dtype)):
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * num_per_thread), threads=threads) as (b_x):
# vector add.
for i, j in T.Parallel(threads, num_per_thread):
@@ -215,7 +215,7 @@ But what happens if we provide additional hints to TileLang? For instance, by ex
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
@T.prim_func
def main(A: T.Buffer((N), dtype), B: T.Buffer((N), dtype), C: T.Buffer((N), dtype)):
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
with T.Kernel(T.ceildiv(N, threads * NUM_ELE_PER_THREAD), threads=threads) as (b_x):
A_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
B_register = T.alloc_fragment((threads * NUM_ELE_PER_THREAD), dtype)
@@ -67,9 +67,9 @@ from tilelang.intrinsics import make_mma_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -66,7 +66,7 @@ For example, consider a case where a simple `T.copy` in 1D causes the lowering p
```python
@T.prim_func
def main(Q: T.Buffer(shape_q, dtype)):
def main(Q: T.Tensor(shape_q, dtype)):
# ...existing code...
```
@@ -46,10 +46,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def MMA0(
K: T.Buffer(shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
@@ -66,10 +66,10 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def MMA1(
V: T.Buffer(shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
@@ -79,13 +79,13 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
@@ -109,19 +109,19 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype),
Output: T.Buffer(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
@@ -79,10 +79,10 @@ def blocksparse_matmul(M,
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
BlockMask: T.Buffer(block_mask_shape, "bool"),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
@@ -46,9 +46,9 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
@T.prim_func
def main(
data: T.Buffer((N, H, W, C), dtype),
kernel: T.Buffer((KH, KW, C, F), dtype),
out: T.Buffer((N, OH, OW, F), dtype),
data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
@@ -58,8 +58,8 @@ def convolution(N, C, H, W, F, K, S, D, P, tune=False):
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
out_shared = T.alloc_shared((block_M, block_N), dtype)
kernel_flat = T.Buffer((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Buffer((N * OH * OW, F), dtype, out.data)
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
@@ -73,9 +73,9 @@ def gpu_2d_continuous_cumsum(
def block_inclusive_inside_block(
batch: T.int32,
cur_len: T.int32,
source: T.Buffer,
output: T.Buffer,
tmp_buf: T.Buffer,
source: T.Tensor,
output: T.Tensor,
tmp_buf: T.Tensor,
src_offset: T.int32,
tmp_offset: T.int32,
):
@@ -125,8 +125,8 @@ def gpu_2d_continuous_cumsum(
def update_cross_block(
batch: T.int32,
cur_len: T.int32,
source: T.Buffer,
output: T.Buffer,
source: T.Tensor,
output: T.Tensor,
src_offset: T.int32,
out_offset: T.int32,
):
@@ -141,8 +141,8 @@ def gpu_2d_continuous_cumsum(
source[by, src_offset + bx - 1], 0)
@T.prim_func
def cumsum(A: T.Buffer((M, N), dtype="int32"), Out: T.Buffer((M, N), dtype="int32"),
Tmp: T.Buffer((M, N), dtype="int32")):
def cumsum(A: T.Tensor((M, N), dtype="int32"), Out: T.Tensor((M, N), dtype="int32"),
Tmp: T.Tensor((M, N), dtype="int32")):
ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", N))))
total_rounds = ceil_log2 // LOG_BLOCK_N
@@ -40,11 +40,11 @@ def tl_gemm(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
scales_a: T.Buffer(Scales_A_shape, "float32"),
scales_b: T.Buffer(Scales_B_shape, "float32"),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -17,11 +17,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
@@ -84,12 +84,12 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
@@ -161,9 +161,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def combine(
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
@@ -198,26 +198,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func
def main_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Buffer([batch, heads, dim], dtype),
Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Buffer([batch, heads, num_split], dtype),
Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
Output: T.Buffer([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
@@ -19,13 +19,13 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
@T.macro
def flash_mla_kernel(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Buffer([batch], "int32"),
Output: T.Buffer([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dv], dtype)
@@ -98,14 +98,14 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
@T.macro
def flash_mla_split_kv_kernel(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
):
with T.Kernel(
batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
@@ -185,9 +185,9 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
@T.macro
def combine(
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(h_q, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dv], dtype)
@@ -222,15 +222,15 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
@T.prim_func
def main_split(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse,
Output_partial)
@@ -238,15 +238,15 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
@T.prim_func
def main_no_split(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)
@@ -6,9 +6,9 @@ An example of implementing a dequantization GEMM:
```python
@T.prim_func
def dequant_matmul(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -38,9 +38,9 @@ def matmul(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -236,9 +236,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
@@ -63,8 +63,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func
def main(
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((N, K), in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
@@ -141,9 +141,9 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main_split(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
@@ -191,9 +191,9 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
Ct: T.Buffer((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -15,9 +15,9 @@ def elementwise_add(
@T.prim_func
def main(
A: T.Buffer((M, N), in_dtype),
B: T.Buffer((M, N), in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor((M, N), in_dtype),
B: T.Tensor((M, N), in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
@@ -5,10 +5,10 @@ Using tile-lang, we can define buffers at different memory layers. For instance,
```python
@T.prim_func
def flash_attention(
Q: T.Buffer(shape, dtype),
K: T.Buffer(shape, dtype),
V: T.Buffer(shape, dtype),
Output: T.Buffer(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
# Launch a specialized T.Kernel with 3D mapping: (bx, by, bz)
# bx: block index in sequence dimension
@@ -18,11 +18,11 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_casual, block_M, bloc
@T.prim_func
def flash_fwd(
Q: T.Buffer(q_shape, dtype), # type: ignore
K: T.Buffer(k_shape, dtype), # type: ignore
V: T.Buffer(v_shape, dtype), # type: ignore
Output: T.Buffer([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
Output: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim_qk], dtype)
@@ -86,9 +86,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
@T.prim_func
def flash_bwd_prep(
O: T.Buffer(shape, dtype), # type: ignore
dO: T.Buffer(shape, dtype), # type: ignore
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
@@ -121,8 +121,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
@T.prim_func
def flash_bwd_post(
dQ: T.Buffer(shape, accum_dtype), # type: ignore
dQ_out: T.Buffer(shape, dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
@@ -146,15 +146,15 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_casual, block_M, bloc
@T.prim_func
def flash_bwd(
Q: T.Buffer(q_shape, dtype), # type: ignore
K: T.Buffer(k_shape, dtype), # type: ignore
V: T.Buffer(v_shape, dtype), # type: ignore
dO: T.Buffer([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Buffer(q_shape, accum_dtype), # type: ignore
dK: T.Buffer(k_shape, dtype), # type: ignore
dV: T.Buffer(v_shape, dtype), # type: ignore
Q: T.Tensor(q_shape, dtype), # type: ignore
K: T.Tensor(k_shape, dtype), # type: ignore
V: T.Tensor(v_shape, dtype), # type: ignore
dO: T.Tensor([batch, seq_len, heads, dim_v], dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(q_shape, accum_dtype), # type: ignore
dK: T.Tensor(k_shape, dtype), # type: ignore
dV: T.Tensor(v_shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim_qk], dtype)