Commit bf8a6fc1 authored by Lei Wang, committed by LeiWang1999

[Refactor] Deprecate `T.Buffer` as an argument annotation and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for the `shared.dyn` and `shared` memory scopes in the Fill operation (a minimal usage sketch follows this list).
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
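
A minimal sketch of the new shared-scope fill path. The kernel name, shapes, and dtype below are illustrative only, not code from this commit:

```python
import tilelang.language as T


def fill_then_store(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Fill now also lowers for buffers in the `shared` / `shared.dyn` scopes
            # (previously fragment/local buffers), using a parallel loop plus layout
            # inference as described above.
            T.fill(A_shared, 0)
            T.copy(A_shared, A[by * block_M, bx * block_N])

    return main
```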

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the `deprecated` decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types (a condensed sketch of such a kernel follows this list).
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.
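
For orientation, a condensed sketch of the kind of tile-level GEMM kernel the test exercises, written against the new `T.Tensor` argument style; block sizes, dtypes, and the pipeline depth are illustrative rather than the exact contents of the new test file:

```python
import tilelang.language as T


def matmul(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Stage tiles of A and B into shared memory, then accumulate into fragments.
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```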

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added a `make_tensor` function in `proxy.py` to streamline tensor creation from pointers (a rough usage sketch follows this list).
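
A rough sketch of the pointer-based construction described above. Both the `T.handle` annotation for raw pointer arguments and the `T.make_tensor(ptr, shape, dtype)` signature are assumptions inferred from this changelog, not verified API:

```python
import tilelang.language as T


def copy_from_ptrs(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(src_ptr: T.handle, dst_ptr: T.handle):  # hypothetical pointer annotations
        # Hypothetical: wrap raw pointers in shaped tensor views,
        # replacing the older `T.Tensor.from_ptr(...)` spelling.
        src = T.make_tensor(src_ptr, (M, N), dtype)
        dst = T.make_tensor(dst_ptr, (M, N), dtype)
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            tile = T.alloc_shared((block_M, block_N), dtype)
            T.copy(src[by * block_M, bx * block_N], tile)
            T.copy(tile, dst[by * block_M, bx * block_N])

    return main
```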

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions (an annotated macro example follows this list).
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
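
For reference, the new annotation style on one of the flash-attention macros (lifted from the hunks at the end of this diff, wrapped in a small factory here so the closed-over sizes are defined):

```python
import tilelang.language as T


def make_rescale(block_M, dim, accum_dtype="float"):

    @T.macro
    def Rescale(
            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),  # register-level fragment
            scores_scale: T.FragmentBuffer([block_M], accum_dtype),  # per-row scale, also a fragment
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    return Rescale
```

Shared-memory staging buffers get the analogous `T.SharedBuffer([...], dtype)` annotation, while plain global arguments stay `T.Tensor`.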

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that kernel inputs and outputs are mapped to `torch.Tensor`.
- This change enhances clarity and consistency in tensor usage across the codebase.
parent 73d2c62e
@@ -8,9 +8,9 @@ def matmul(M, N, K, block_M, block_N, block_K, threads, dtype="float16", accum_d
 @T.prim_func
 def main(
- A: T.Buffer((M, K), dtype),
- B: T.Buffer((K, N), dtype),
- C: T.Buffer((M, N), dtype),
+ A: T.Tensor((M, K), dtype),
+ B: T.Tensor((K, N), dtype),
+ C: T.Tensor((M, N), dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -8,9 +8,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 @T.prim_func
 def main(
- A: T.Buffer((M, K), dtype),
- B: T.Buffer((N, K), dtype),
- C: T.Buffer((M, N), dtype),
+ A: T.Tensor((M, K), dtype),
+ B: T.Tensor((N, K), dtype),
+ C: T.Tensor((M, N), dtype),
 ):
 with T.Kernel(
 T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
...
@@ -29,9 +29,9 @@ def matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -139,9 +139,9 @@ def matmu_jit_kernel(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -31,9 +31,9 @@ def matmul(
 )
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -147,9 +147,9 @@ def matmu_jit_kernel(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -28,9 +28,9 @@ def matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -138,9 +138,9 @@ def matmu_jit_kernel(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -28,9 +28,9 @@ def matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -138,9 +138,9 @@ def matmu_jit_kernel(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -498,9 +498,9 @@ def matmul_int_variable(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 offset: T.int32,
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -570,9 +570,9 @@ def matmul_float_variable(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 offset: T.float32,
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -106,9 +106,9 @@ def tl_matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -11,9 +11,9 @@ def convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, bloc
 @T.prim_func
 def main(
- data: T.Buffer((N, H, W, C), in_dtype),
- kernel: T.Buffer((KH, KW, C, F), in_dtype),
- out: T.Buffer((N, OH, OW, F), out_dtype),
+ data: T.Tensor((N, H, W, C), in_dtype),
+ kernel: T.Tensor((KH, KW, C, F), in_dtype),
+ out: T.Tensor((N, OH, OW, F), out_dtype),
 ):
 with T.Kernel(
 T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
@@ -22,8 +22,8 @@ def convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, bloc
 kernel_shared = T.alloc_shared((block_K, block_N), in_dtype)
 out_local = T.alloc_fragment((block_M, block_N), dtypeAccum)
- kernel_flat = T.Buffer((KH * KW * C, F), in_dtype, kernel.data)
- out_flat = T.Buffer((N * OH * OW, F), out_dtype, out.data)
+ kernel_flat = T.Tensor((KH * KW * C, F), in_dtype, kernel.data)
+ out_flat = T.Tensor((N * OH * OW, F), out_dtype, out.data)
 T.clear(out_local)
 for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...
@@ -64,8 +64,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
 @T.prim_func
 def main(
- B: T.Buffer(B_shape, storage_dtype),
- C: T.Buffer((N, K), in_dtype),
+ B: T.Tensor(B_shape, storage_dtype),
+ C: T.Tensor((N, K), in_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
 B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
@@ -132,9 +132,9 @@ def matmul_fp16xfp4(M,
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, storage_dtype),
- Ct: T.Buffer((N, M), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, storage_dtype),
+ Ct: T.Tensor((N, M), out_dtype),
 ):
 with T.Kernel(
 T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -239,9 +239,9 @@ def matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, storage_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, storage_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -437,9 +437,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, storage_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, storage_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(
 T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
...
@@ -16,9 +16,9 @@ def elementwise_add(
 @T.prim_func
 def main(
- A: T.Buffer((M, N), in_dtype),
- B: T.Buffer((M, N), in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor((M, N), in_dtype),
+ B: T.Tensor((M, N), in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 start_x = bx * block_N
...
@@ -13,13 +13,13 @@ def chunk_scan_fwd(batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate,
 p = 1.44269504
 @T.prim_func
- def main(cb: T.Buffer((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Buffer(
- (batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
- (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
+ def main(cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), x: T.Tensor(
+ (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
+ (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
 (batch, nheads, nchunks, chunk_size), dtype),
- C: T.Buffer((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Buffer(
- (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Buffer(
- (nheads), dtype), Output: T.Buffer((batch, seqlen, nheads, headdim), dtype)):
+ C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), prev_states: T.Tensor(
+ (batch, nchunks, nheads, headdim, dstate), dtype), D: T.Tensor(
+ (nheads), dtype), Output: T.Tensor((batch, seqlen, nheads, headdim), dtype)):
 with T.Kernel(
 nheads,
 T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
@@ -199,10 +199,10 @@ def chunk_state_fwd(batch,
 p = 1.44269504
 @T.prim_func
- def main(B: T.Buffer((batch, seqlen, ngroups, dstate), dtype), x: T.Buffer(
- (batch, seqlen, nheads, headdim), dtype), dt: T.Buffer(
- (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Buffer(
- (batch, nheads, nchunks, chunk_size), dtype), Output: T.Buffer(
+ def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
+ (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
+ (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
+ (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
 (batch, nchunks, nheads, headdim, dstate), dtype)):
 with T.Kernel(
 nheads,
...
@@ -15,9 +15,9 @@ def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype):
 @T.prim_func
 def main(
- A: T.Buffer((M, K), in_dtype),
- B: T.Buffer((N, K), in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor((M, K), in_dtype),
+ B: T.Tensor((N, K), in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by):
 A_shared = T.alloc_shared((bM, bK), in_dtype)
...
@@ -105,9 +105,9 @@ def tl_matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -50,10 +50,10 @@ def gemv_simt(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- Bias: T.Buffer(Bias_shape, out_dtype),
- C: T.Buffer(C_shape, out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ Bias: T.Tensor(Bias_shape, out_dtype),
+ C: T.Tensor(C_shape, out_dtype),
 ):
 with T.Kernel(
 T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
...
@@ -26,9 +26,9 @@ def matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -106,9 +106,9 @@ def tl_matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -76,9 +76,9 @@ def tl_matmul_simt(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer(C_shape, out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor(C_shape, out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -50,10 +50,10 @@ def gemv_simt(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- Bias: T.Buffer(Bias_shape, out_dtype),
- C: T.Buffer(C_shape, out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ Bias: T.Tensor(Bias_shape, out_dtype),
+ C: T.Tensor(C_shape, out_dtype),
 ):
 with T.Kernel(
 T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
...
@@ -92,9 +92,9 @@ def tl_matmul(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -280,9 +280,9 @@ def tl_matmul_weight_only_transform(
 @T.prim_func
 def main(
- A: T.Buffer(A_shape, in_dtype),
- B: T.Buffer(B_shape, in_dtype),
- C: T.Buffer((M, N), out_dtype),
+ A: T.Tensor(A_shape, in_dtype),
+ B: T.Tensor(B_shape, in_dtype),
+ C: T.Tensor((M, N), out_dtype),
 ):
 with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -11,10 +11,10 @@ def flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stage
 @T.macro
 def MMA0(
- K: T.Buffer(shape, dtype),
- Q_shared: T.Buffer([block_M, dim], dtype),
- K_shared: T.Buffer([block_N, dim], dtype),
- acc_s: T.Buffer([block_M, block_N], accum_dtype),
+ K: T.Tensor(shape, dtype),
+ Q_shared: T.SharedBuffer([block_M, dim], dtype),
+ K_shared: T.SharedBuffer([block_N, dim], dtype),
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
 k: T.int32,
 bx: T.int32,
 by: T.int32,
@@ -31,26 +31,26 @@ def flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stage
 @T.macro
 def MMA1(
- V: T.Buffer(shape, dtype),
- V_shared: T.Buffer([block_M, dim], dtype),
- acc_s_cast: T.Buffer([block_M, block_N], dtype),
- acc_o: T.Buffer([block_M, dim], accum_dtype),
+ V: T.Tensor(shape, dtype),
+ V_shared: T.SharedBuffer([block_M, dim], dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
 k: T.int32,
 by: T.int32,
 bz: T.int32,
 ):
 T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 @T.macro
 def Softmax(
- acc_s: T.Buffer([block_M, block_N], accum_dtype),
- acc_s_cast: T.Buffer([block_M, block_N], dtype),
- scores_max: T.Buffer([block_M], accum_dtype),
- scores_max_prev: T.Buffer([block_M], accum_dtype),
- scores_scale: T.Buffer([block_M], accum_dtype),
- scores_sum: T.Buffer([block_M], accum_dtype),
- logsum: T.Buffer([block_M], accum_dtype),
+ acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+ acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+ scores_max: T.FragmentBuffer([block_M], accum_dtype),
+ scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+ scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+ logsum: T.FragmentBuffer([block_M], accum_dtype),
 ):
 T.copy(scores_max, scores_max_prev)
 T.fill(scores_max, -T.infinity(accum_dtype))
@@ -74,18 +74,18 @@ def flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stage
 @T.macro
 def Rescale(
- acc_o: T.Buffer([block_M, dim], accum_dtype),
- scores_scale: T.Buffer([block_M], accum_dtype),
+ acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+ scores_scale: T.FragmentBuffer([block_M], accum_dtype),
 ):
 for i, j in T.Parallel(block_M, dim):
 acc_o[i, j] *= scores_scale[i]
 @T.prim_func
 def main(
- Q: T.Buffer(shape, dtype),
- K: T.Buffer(shape, dtype),
- V: T.Buffer(shape, dtype),
- Output: T.Buffer(shape, dtype),
+ Q: T.Tensor(shape, dtype),
+ K: T.Tensor(shape, dtype),
+ V: T.Tensor(shape, dtype),
+ Output: T.Tensor(shape, dtype),
 ):
 with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
 Q_shared = T.alloc_shared([block_M, dim], dtype)
...