"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "b2e5ad6b86ca8cfa5427608b8a76dca1207807bb"
Commit bf8a6fc1 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Deprecate `T.Buffer` as an argument annotation and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to use parallel processing for better performance (see the sketch after this list).
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.
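
A minimal sketch of the `T.Parallel` rewrite described above, modeled on the `Rescale` macro that appears in the diff below; the names and shapes of the `lse` buffers are illustrative assumptions, not the exact code from the example:

```python
# Hypothetical buffers; only the T.Parallel pattern is the point here.
@T.macro
def accumulate_lse(
        lse_local_split: T.FragmentBuffer([block_M], accum_dtype),
        lse_logsum_local: T.FragmentBuffer([block_M], accum_dtype),
):
    # Each element updates independently, so the loop can be distributed
    # across threads instead of being annotated and kept serial.
    for i in T.Parallel(block_M):
        lse_logsum_local[i] += lse_local_split[i]
```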

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared-memory scenarios (see the sketch after this list).
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
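
To illustrate, a minimal sketch of `T.fill` applied directly to a shared-scope allocation; the kernel and shapes are hypothetical, while the entry points mirror those used elsewhere in this diff:

```python
import tilelang.language as T

def fill_block(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # A_shared lives in the "shared"/"shared.dyn" scope that the Fill
            # operation now lowers via a partitioned, vectorized thread loop.
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            T.fill(A_shared, 0)
            T.copy(A_shared, A[by * block_M, bx * block_N])

    return main
```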

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Moved the `deprecated` decorator out of the main module into a new implementation in the utils module for better organization (see the sketch after this list).
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.
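
The relocated decorator plausibly follows the usual pattern; this is a sketch under that assumption, not the exact code added to the utils module:

```python
import functools
import warnings

def deprecated(replacement: str):
    """Mark a function as deprecated, pointing callers at its replacement."""

    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{fn.__name__} is deprecated, use {replacement} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```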

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types (see the sketch after this list).
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.
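
The kernel under test follows the standard TileLang GEMM pattern; here is a sketch in that spirit, mirroring the README example updated later in this diff (block sizes, pipelining depth, and the use of `T.Pipelined` are illustrative rather than copied from the new test file):

```python
import tilelang.language as T

def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Stream K-tiles through shared memory and accumulate in registers.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```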

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added a `make_tensor` function in `proxy.py` to streamline tensor creation from pointers (see the sketch after this list).
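
A hedged sketch of the new entry point: a kernel that takes raw pointers and materializes tensor views with `make_tensor`. The pointer parameter type (`T.handle`) and the exact `make_tensor` signature are assumptions based on the description above; the shape is a compile-time constant here for simplicity:

```python
import tilelang.language as T

def add_one(M, N, block_M, block_N, dtype="float16"):

    @T.prim_func
    def main(a_ptr: T.handle, b_ptr: T.handle):
        # Assumed form: make_tensor(pointer, shape, dtype) -> tensor view.
        A = T.make_tensor(a_ptr, (M, N), dtype)
        B = T.make_tensor(b_ptr, (M, N), dtype)
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] + 1

    return main
```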

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity (see the before/after sketch after this list).
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.
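
Reduced to its smallest form, the change looks like this (a minimal sketch; both annotations describe the same underlying TIR buffer, but the new spelling signals that the JIT maps the argument to a `torch.Tensor`):

```python
# Before: buffer-flavored annotation for a kernel argument.
@T.prim_func
def main_old(A: T.Buffer((1024, 1024), "float16")):
    T.evaluate(0)

# After: tensor-flavored annotation; the argument itself is unchanged.
@T.prim_func
def main_new(A: T.Tensor((1024, 1024), "float16")):
    T.evaluate(0)
```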

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions (see the example after this list).
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
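
After this rename, a macro signature spells out where each operand lives: `T.Tensor` for global memory, `T.SharedBuffer` for shared memory, and `T.FragmentBuffer` for register fragments. The `Rescale` macro from the updated examples shows the pattern:

```python
@T.macro
def Rescale(
        acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),  # register fragment
        scores_scale: T.FragmentBuffer([block_M], accum_dtype),  # register fragment
):
    for i, j in T.Parallel(block_M, dim):
        acc_o[i, j] *= scores_scale[i]
```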

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped to `torch.Tensor` (see the sketch below).
- This change enhances clarity and consistency in tensor usage across the codebase.
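
A minimal sketch of the alias, assuming `proxy.py` already defines the `Buffer` proxy class nearby:

```python
# proxy.py (sketch): `Buffer` is the existing proxy class defined above.
# The alias gives kernel signatures a tensor-flavored name for arguments
# that the JIT maps to torch.Tensor inputs and outputs.
Tensor = Buffer
```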
parent 73d2c62e
@@ -69,10 +69,10 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def MMA0(
-            K: T.Buffer(kv_shape, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(kv_shape, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
@@ -89,26 +89,26 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def MMA1(
-            V: T.Buffer(kv_shape, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(kv_shape, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -132,18 +132,18 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
...
@@ -36,10 +36,10 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def MMA0(
-            K: T.Buffer(kv_shape, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(kv_shape, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
@@ -56,26 +56,26 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def MMA1(
-            V: T.Buffer(kv_shape, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(kv_shape, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -99,18 +99,18 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
...
@@ -14,11 +14,11 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
    @T.prim_func
    def flash_fwd(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
-            lse: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
+            Q: T.Tensor(shape, dtype),  # type: ignore
+            K: T.Tensor(shape, dtype),  # type: ignore
+            V: T.Tensor(shape, dtype),  # type: ignore
+            Output: T.Tensor(shape, dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -86,9 +86,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    @T.prim_func
    def flash_bwd_prep(
-            O: T.Buffer(shape, dtype),  # type: ignore
-            dO: T.Buffer(shape, dtype),  # type: ignore
-            Delta: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
+            O: T.Tensor(shape, dtype),  # type: ignore
+            dO: T.Tensor(shape, dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
@@ -121,8 +121,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    @T.prim_func
    def flash_bwd_post(
-            dQ: T.Buffer(shape, accum_dtype),  # type: ignore
-            dQ_out: T.Buffer(shape, dtype),  # type: ignore
+            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
+            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
@@ -143,15 +143,15 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
    @T.prim_func
    def flash_bwd(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            dO: T.Buffer(shape, dtype),  # type: ignore
-            lse: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Buffer(shape, accum_dtype),  # type: ignore
-            dK: T.Buffer(shape, dtype),  # type: ignore
-            dV: T.Buffer(shape, dtype),  # type: ignore
+            Q: T.Tensor(shape, dtype),  # type: ignore
+            K: T.Tensor(shape, dtype),  # type: ignore
+            V: T.Tensor(shape, dtype),  # type: ignore
+            dO: T.Tensor(shape, dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
+            dK: T.Tensor(shape, dtype),  # type: ignore
+            dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
...
@@ -18,11 +18,11 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
    @T.prim_func
    def flash_fwd(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            Output: T.Buffer(shape, dtype),  # type: ignore
-            lse: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
+            Q: T.Tensor(shape, dtype),  # type: ignore
+            K: T.Tensor(shape, dtype),  # type: ignore
+            V: T.Tensor(shape, dtype),  # type: ignore
+            Output: T.Tensor(shape, dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -90,9 +90,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
    @T.prim_func
    def flash_bwd_prep(
-            O: T.Buffer(shape, dtype),  # type: ignore
-            dO: T.Buffer(shape, dtype),  # type: ignore
-            Delta: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
+            O: T.Tensor(shape, dtype),  # type: ignore
+            dO: T.Tensor(shape, dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
            o = T.alloc_fragment([blk, blk], dtype)
@@ -125,8 +125,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
    @T.prim_func
    def flash_bwd_post(
-            dQ: T.Buffer(shape, accum_dtype),  # type: ignore
-            dQ_out: T.Buffer(shape, dtype),  # type: ignore
+            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
+            dQ_out: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
            T.annotate_layout({dQ: make_dq_layout(dQ)})
@@ -147,15 +147,15 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
    @T.prim_func
    def flash_bwd(
-            Q: T.Buffer(shape, dtype),  # type: ignore
-            K: T.Buffer(shape, dtype),  # type: ignore
-            V: T.Buffer(shape, dtype),  # type: ignore
-            dO: T.Buffer(shape, dtype),  # type: ignore
-            lse: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
-            Delta: T.Buffer([batch, heads, seq_len], accum_dtype),  # type: ignore
-            dQ: T.Buffer(shape, accum_dtype),  # type: ignore
-            dK: T.Buffer(shape, dtype),  # type: ignore
-            dV: T.Buffer(shape, dtype),  # type: ignore
+            Q: T.Tensor(shape, dtype),  # type: ignore
+            K: T.Tensor(shape, dtype),  # type: ignore
+            V: T.Tensor(shape, dtype),  # type: ignore
+            dO: T.Tensor(shape, dtype),  # type: ignore
+            lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
+            dQ: T.Tensor(shape, accum_dtype),  # type: ignore
+            dK: T.Tensor(shape, dtype),  # type: ignore
+            dV: T.Tensor(shape, dtype),  # type: ignore
    ):
        with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
            K_shared = T.alloc_shared([block_M, dim], dtype)
...
@@ -35,10 +35,10 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
    @T.macro
    def MMA0(
-            K: T.Buffer(kv_shape, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(kv_shape, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
@@ -57,26 +57,26 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
    @T.macro
    def MMA1(
-            V: T.Buffer(kv_shape, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(kv_shape, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -101,18 +101,18 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, tune=False):
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
-            Q: T.Buffer(q_shape, dtype),
-            K: T.Buffer(kv_shape, dtype),
-            V: T.Buffer(kv_shape, dtype),
-            Output: T.Buffer(q_shape, dtype),
+            Q: T.Tensor(q_shape, dtype),
+            K: T.Tensor(kv_shape, dtype),
+            V: T.Tensor(kv_shape, dtype),
+            Output: T.Tensor(q_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
...
@@ -34,10 +34,10 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def MMA0(
-            K: T.Buffer(shape, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(shape, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
@@ -54,26 +54,26 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def MMA1(
-            V: T.Buffer(shape, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(shape, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -97,18 +97,18 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
-            Q: T.Buffer(shape, dtype),
-            K: T.Buffer(shape, dtype),
-            V: T.Buffer(shape, dtype),
-            Output: T.Buffer(shape, dtype),
+            Q: T.Tensor(shape, dtype),
+            K: T.Tensor(shape, dtype),
+            V: T.Tensor(shape, dtype),
+            Output: T.Tensor(shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
...
@@ -34,10 +34,10 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def MMA0(
-            K: T.Buffer(shape, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(shape, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
@@ -54,26 +54,26 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def MMA1(
-            V: T.Buffer(shape, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(shape, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
    ):
        T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -97,18 +97,18 @@ def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.prim_func
    def main(
-            Q: T.Buffer(shape, dtype),
-            K: T.Buffer(shape, dtype),
-            V: T.Buffer(shape, dtype),
-            Output: T.Buffer(shape, dtype),
+            Q: T.Tensor(shape, dtype),
+            K: T.Tensor(shape, dtype),
+            V: T.Tensor(shape, dtype),
+            Output: T.Tensor(shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
...
@@ -235,13 +235,13 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal):
    @T.prim_func
    def main(
-            Q_unpad: T.Buffer(q_shape, dtype),
-            K_unpad: T.Buffer(k_shape, dtype),
-            V_unpad: T.Buffer(v_shape, dtype),
-            cu_seqlens_q: T.Buffer([batch_size + 1], "int32"),
-            cu_seqlens_k: T.Buffer([batch_size + 1], "int32"),
+            Q_unpad: T.Tensor(q_shape, dtype),
+            K_unpad: T.Tensor(k_shape, dtype),
+            V_unpad: T.Tensor(v_shape, dtype),
+            cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
+            cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
            max_seqlen_q: T.int32,
-            Output_unpad: T.Buffer(o_shape, dtype),
+            Output_unpad: T.Tensor(o_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(max_seqlen_q, block_M), heads, batch_size,
...
@@ -18,10 +18,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def MMA0(
-            K: T.Buffer(shape_kv, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(shape_kv, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            mid: T.int32,
            hid: T.int32,
@@ -42,10 +42,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def MMA1(
-            V: T.Buffer(shape_kv, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(shape_kv, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            hid: T.int32,
            bid: T.int32,
@@ -58,13 +58,13 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -88,19 +88,19 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.macro
    def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_kv, dtype),
-            V: T.Buffer(shape_kv, dtype),
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_kv, dtype),
+            V: T.Tensor(shape_kv, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
@@ -151,9 +151,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def combine(
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_q, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
+            Output: T.Tensor(shape_q, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
            po_local = T.alloc_fragment([block_M, dim], dtype)
@@ -201,12 +201,12 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.prim_func
    def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_kv, dtype),
-            V: T.Buffer(shape_kv, dtype),
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
-            Output: T.Buffer(shape_q, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_kv, dtype),
+            V: T.Tensor(shape_kv, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
+            Output: T.Tensor(shape_q, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)
...
@@ -42,11 +42,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
    @T.macro
    def flash_attn(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
-            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
-            Output: T.Buffer([batch, heads, dim], dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
+            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(
                batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
@@ -111,12 +111,12 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
    @T.macro
    def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
-            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
-            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
+            glse: T.Tensor([batch, heads, num_split], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(
                batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
@@ -195,9 +195,9 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
    @T.macro
    def combine(
-            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            glse: T.Tensor([batch, heads, num_split], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
+            Output: T.Tensor(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
@@ -238,26 +238,26 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, tune=False):
    @T.prim_func
    def main_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
-            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
-            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
+            glse: T.Tensor([batch, heads, num_split], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
+            Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, mask, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def main_no_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
-            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
-            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_k, dtype),
+            V: T.Tensor(shape_v, dtype),
+            mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
+            glse: T.Tensor([batch, heads, num_split], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
+            Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn(Q, K, V, mask, Output)
...
@@ -18,10 +18,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def MMA0(
-            K: T.Buffer(shape_kv, dtype),
-            Q_shared: T.Buffer([block_M, dim], dtype),
-            K_shared: T.Buffer([block_N, dim], dtype),
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
+            K: T.Tensor(shape_kv, dtype),
+            Q_shared: T.SharedBuffer([block_M, dim], dtype),
+            K_shared: T.SharedBuffer([block_N, dim], dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
            k: T.int32,
            mid: T.int32,
            hid: T.int32,
@@ -42,10 +42,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def MMA1(
-            V: T.Buffer(shape_kv, dtype),
-            V_shared: T.Buffer([block_M, dim], dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
+            V: T.Tensor(shape_kv, dtype),
+            V_shared: T.SharedBuffer([block_M, dim], dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
            k: T.int32,
            hid: T.int32,
            bid: T.int32,
@@ -58,13 +58,13 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def Softmax(
-            acc_s: T.Buffer([block_M, block_N], accum_dtype),
-            acc_s_cast: T.Buffer([block_M, block_N], dtype),
-            scores_max: T.Buffer([block_M], accum_dtype),
-            scores_max_prev: T.Buffer([block_M], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
-            scores_sum: T.Buffer([block_M], accum_dtype),
-            logsum: T.Buffer([block_M], accum_dtype),
+            acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
+            acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
+            scores_max: T.FragmentBuffer([block_M], accum_dtype),
+            scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
+            scores_sum: T.FragmentBuffer([block_M], accum_dtype),
+            logsum: T.FragmentBuffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
@@ -88,19 +88,19 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def Rescale(
-            acc_o: T.Buffer([block_M, dim], accum_dtype),
-            scores_scale: T.Buffer([block_M], accum_dtype),
+            acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
+            scores_scale: T.FragmentBuffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.macro
    def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_kv, dtype),
-            V: T.Buffer(shape_kv, dtype),
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_kv, dtype),
+            V: T.Tensor(shape_kv, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
@@ -150,9 +150,9 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.macro
    def combine(
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_q, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),
+            Output: T.Tensor(shape_q, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
            po_local = T.alloc_fragment([block_M, dim], dtype)
@@ -201,12 +201,12 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
    @T.prim_func
    def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_kv, dtype),
-            V: T.Buffer(shape_kv, dtype),
-            glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
-            Output: T.Buffer(shape_q, dtype),
+            Q: T.Tensor(shape_q, dtype),
+            K: T.Tensor(shape_kv, dtype),
+            V: T.Tensor(shape_kv, dtype),
+            glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
+            Output_partial: T.Tensor(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
+            Output: T.Tensor(shape_q, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)
...
@@ -53,9 +53,9 @@ import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        # Define a grid with enough blocks to cover M×N
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -176,9 +176,9 @@ from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Allocate shared and local fragments
@@ -326,9 +326,9 @@ def tl_matmul(
    @T.prim_func
    def main(
-            A: T.Buffer(A_shape, in_dtype),
-            B: T.Buffer(B_shape, in_dtype),
-            C: T.Buffer((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -92,9 +92,9 @@ def get_best_config(M, N, K, with_roller=False):
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
@@ -143,9 +143,9 @@ def matmul(M,
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -100,9 +100,9 @@ def tl_matmul(
    @T.prim_func
    def main(
-            A: T.Buffer(A_shape, in_dtype),
-            B: T.Buffer(B_shape, in_dtype),
-            C: T.Buffer((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -7,9 +7,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((K, N), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -15,9 +15,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -12,9 +12,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), accum_dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -105,9 +105,9 @@ def tl_matmul(
    @T.prim_func
    def main(
-            A: T.Buffer(A_shape, in_dtype),
-            B: T.Buffer(B_shape, in_dtype),
-            C: T.Buffer((M, N), out_dtype),
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
...
@@ -9,9 +9,9 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
    @T.prim_func
    def main(
-            A: T.Buffer((M, K), dtype),
-            B: T.Buffer((N, K), dtype),
-            C: T.Buffer((M, N), dtype),
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
...
@@ -79,12 +79,12 @@ def tl_matmul_streamk(
    @T.macro
    def compute_first_wave(
            pid: T.int32,
-            A_buf: T.Buffer,
-            A_buf_shared: T.Buffer,
-            B_buf: T.Buffer,
-            B_buf_shared: T.Buffer,
-            C: T.Buffer,
-            C_local: T.Buffer,
+            A_buf: T.Tensor,
+            A_buf_shared: T.SharedBuffer,
+            B_buf: T.Tensor,
+            B_buf_shared: T.SharedBuffer,
+            C: T.Tensor,
+            C_local: T.LocalBuffer,
    ):
        start_iter = T.alloc_fragment((1,), "int32", "local")
        end_iter = T.alloc_fragment((1,), "int32", "local")
@@ -127,12 +127,12 @@ def tl_matmul_streamk(
    @T.macro
    def compute_full_tiles(
            pid: T.int32,
-            A_buf: T.Buffer,
-            A_shared: T.Buffer,
-            B_buf: T.Buffer,
-            B_shared: T.Buffer,
-            C: T.Buffer,
-            C_local: T.Buffer,
+            A_buf: T.Tensor,
+            A_shared: T.SharedBuffer,
+            B_buf: T.Tensor,
+            B_shared: T.SharedBuffer,
+            C: T.Tensor,
+            C_local: T.LocalBuffer,
    ):
        for p in T.serial(sm_patition_factor):
@@ -149,9 +149,9 @@ def tl_matmul_streamk(
    @T.prim_func
    def main(
-            A: T.Buffer(A_shape, dtypeAB),
-            B: T.Buffer(B_shape, dtypeAB),
-            C: T.Buffer((M, N), dtypeC),
+            A: T.Tensor(A_shape, dtypeAB),
+            B: T.Tensor(B_shape, dtypeAB),
+            C: T.Tensor((M, N), dtypeC),
    ):
        with T.Kernel(streamk_programs, threads=threads) as pid:
...