Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and DeepSeek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
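
The canonical-name mapping described above can be sketched roughly as follows. The dictionary contents, function name, and error text here are illustrative assumptions for this changelog, not tilelang's actual implementation:

```python
# Hypothetical sketch of a canonical-dtype display mapping (names assumed,
# not tilelang's real table): short or legacy spellings resolve to one
# canonical display string when a dtype is created from user input.
CANONICAL_DTYPE_NAMES = {
    "float": "float32",   # bare "float" is treated as float32
    "half": "float16",
    "float16": "float16",
    "float32": "float32",
    "int32": "int32",
}

def make_dtype(name: str) -> str:
    """Resolve a user-supplied dtype string to its canonical display form."""
    try:
        return CANONICAL_DTYPE_NAMES[name]
    except KeyError:
        # Clearer feedback on invalid inputs, in the spirit of the refined
        # error messages mentioned above.
        raise ValueError(f"unsupported dtype string: {name!r}")
```

A lookup table like this keeps string inputs and canonical display strings in one place, so both `"float"` and `"float32"` produce the same dtype.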

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
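
The corrected call shape can be shown with a minimal stand-in; the class and its mapping are illustrative only (the real dtype objects live in tilelang), and the mapping returns torch dtype names as strings to stay dependency-free:

```python
class DType:
    """Toy stand-in for a dtype object exposing as_torch()."""

    _TORCH_NAMES = {
        "float16": "torch.float16",
        "float32": "torch.float32",
        "int32": "torch.int32",
    }

    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Correct usage is a single attribute access: dtype.as_torch().
        # The bug fixed above was a doubled dot (dtype..as_torch()),
        # which is a Python syntax error.
        return self._TORCH_NAMES[self.name]

dtype = DType("float16")
torch_name = dtype.as_torch()
```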

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
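
The rename can be checked mechanically with a small helper. The helper is a sketch written for this note; only the one rename described above is handled (PyTorch's own name for this type is `float8_e4m3fn`, the finite/NaN-only variant):

```python
def normalize_float8_name(name: str) -> str:
    """Map the old float8_e4m3 spelling to the 'fn' (finite/NaN-only)
    variant used by PyTorch; anything else is returned unchanged."""
    if name == "float8_e4m3":
        return "float8_e4m3fn"
    return name
```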

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -10,9 +10,9 @@ from index import prepare_token_indices
 from utils import get_abs_err, get_err_ratio
-BF16 = "bfloat16"
-FP32 = "float32"
-INT32 = "int32"
+BF16 = T.bfloat16
+FP32 = T.float32
+INT32 = T.int32
 pass_configs = {
     tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
......
@@ -13,11 +13,11 @@ def preprocess(
     D,
     block_ND=32,
     num_stages=5,
-    dtype="bfloat16",
-    accum_dtype="float",
+    dtype=T.bfloat16,
+    accum_dtype=T.float32,
 ):
-    assert dtype == "bfloat16"
-    assert accum_dtype == "float"
+    assert dtype == T.bfloat16
+    assert accum_dtype == T.float32
     S = T.symbolic("S")
@@ -53,11 +53,11 @@ def postprocess(
     kv_group=1,
     block_N=64,
     threads=128,
-    dtype="bfloat16",
-    accum_dtype="float",
+    dtype=T.bfloat16,
+    accum_dtype=T.float32,
 ):
-    assert dtype == "bfloat16"
-    assert accum_dtype == "float"
+    assert dtype == T.bfloat16
+    assert accum_dtype == T.float32
     S_kv = T.symbolic("S_kv")
     dkv_shape = [S_kv, kv_group, D + D_tail]
@@ -94,15 +94,15 @@ def bwd(
     block_size=32,
     num_stages=0,
     threads=128,
-    indices_dtype="int32",
-    dtype="bfloat16",
-    accum_dtype="float",
+    indices_dtype=T.int32,
+    dtype=T.bfloat16,
+    accum_dtype=T.float32,
 ):
     assert is_causal == True, "non-casual is not supported now"
     assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
-    assert dtype == "bfloat16"
-    assert accum_dtype == "float"
-    assert indices_dtype == "int32"
+    assert dtype == T.bfloat16
+    assert accum_dtype == T.float32
+    assert indices_dtype == T.int32
     if sm_scale is None:
         sm_scale = (D + D_tail) ** (-0.5)
@@ -119,9 +119,9 @@ def bwd(
     lse_shape = [S, H]
     offsets_shape = [B_plus_one]
     token_indices_shape = [S, 2]
-    assert indices_dtype == "int32"
-    assert dtype == "bfloat16"
-    assert accum_dtype == "float"
+    assert indices_dtype == T.int32
+    assert dtype == T.bfloat16
+    assert accum_dtype == T.float32
     H = H_kv
     padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
......
@@ -47,9 +47,9 @@ def sparse_mla_fwd(
     lse_shape = [seq_len, heads]
     offsets_shape = [batch_plus_one]
     token_indices_shape = [seq_len, 2]
-    indices_dtype = "int32"
-    dtype = "bfloat16"
-    accum_dtype = "float"
+    indices_dtype = T.int32
+    dtype = T.bfloat16
+    accum_dtype = T.float32
     G = kv_group
     H = head_kv
......
@@ -8,9 +8,9 @@ from einops import repeat, rearrange, einsum
 from index import prepare_token_indices
 from utils import get_abs_err, get_err_ratio
-BF16 = "bfloat16"
-FP32 = "float32"
-INT32 = "int32"
+BF16 = T.bfloat16
+FP32 = T.float32
+INT32 = T.int32
 pass_configs = {
     tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
@@ -41,9 +41,9 @@ def tl_sparse_mla_topk_reducesum_impl(
     seq_len_kv = T.symbolic("seq_len_kv")
     head_kv = heads // kv_group
-    indices_dtype = "int32"
-    dtype = "bfloat16"
-    accum_dtype = "float"
+    indices_dtype = T.int32
+    dtype = T.bfloat16
+    accum_dtype = T.float32
     G = kv_group
     H = head_kv
......
@@ -98,8 +98,8 @@ def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtyp
 def main(M=16384, N=16384, K=16384):
     block_M, block_N, block_K = 128, 128, 32
     trans_A, trans_B = False, False
-    in_dtype, out_dtype = "float16", "float16"
-    accum_dtype = "float32"
+    in_dtype, out_dtype = T.float16, T.float16
+    accum_dtype = T.float32
     num_stages = 3
     threads = 128
     matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads)
......
@@ -43,11 +43,11 @@ def main(M=1024, N=1024, use_autotune=False):
     b = torch.randn(M, N, dtype=torch.float32, device="cuda")
     if use_autotune:
-        kernel = elementwise_add(M, N, in_dtype="float32", out_dtype="float32")
+        kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32)
     else:
         # Default config
         config = {"block_M": 32, "block_N": 32, "threads": 128}
-        kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
+        kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
......
@@ -17,8 +17,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -89,8 +89,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim_v]
     blk = 32
@@ -129,8 +129,8 @@ def make_dq_layout(dQ):
     },
 )
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim_qk):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim_qk]
     blk = 64
@@ -161,8 +161,8 @@ def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, bl
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
@@ -256,8 +256,8 @@ def flashattn_bwd_split(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M
     v_shape = [batch, seq_len, head_kv, dim_v]
     dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
     dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -20,8 +20,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -94,8 +94,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim_v]
     blk = 32
@@ -134,8 +134,8 @@ def make_dq_layout(dQ):
     },
 )
 def flashattn_bwd_postprocess(batch, heads, head_kv, seq_len, dim_qk, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
@@ -178,8 +178,8 @@ def flashattn_bwd_atomic_add(batch, heads, seq_len, dim_qk, dim_v, is_causal, bl
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
@@ -276,8 +276,8 @@ def flashattn_bwd_split_novarlen(batch, heads, seq_len, dim_qk, dim_v, is_causal
     v_shape = [batch, seq_len, head_kv, dim_v]
     dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
     dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -33,16 +33,16 @@ def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, d
     k_shape = [total_kv, head_kv, dim_qk]
     v_shape = [total_kv, head_kv, dim_v]
     o_shape = [total_q, heads, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
         Q: T.Tensor(q_shape, dtype),  # type: ignore
         K: T.Tensor(k_shape, dtype),  # type: ignore
         V: T.Tensor(v_shape, dtype),  # type: ignore
-        cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
+        cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
         Output: T.Tensor(o_shape, dtype),  # type: ignore
         lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
     ):
@@ -143,8 +143,8 @@ def flashattn_fwd(batch, total_q, total_kv, N_CTX, heads, max_seq_len, dim_qk, d
     },
 )
 def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [total_q, heads, dim_v]
     blk = 32
@@ -152,7 +152,7 @@ def flashattn_bwd_preprocess(batch, heads, total_q, N_CTX, max_seq_len, dim_v):
     def flash_bwd_prep(
         O: T.Tensor(shape, dtype),  # type: ignore
         dO: T.Tensor(shape, dtype),  # type: ignore
-        cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
+        cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
         Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
     ):
         with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch) as (bx, by, bz):
@@ -198,8 +198,8 @@ def make_dq_layout(dQ):
     },
 )
 def flashattn_bwd_postprocess(total_q, total_kv, heads, head_kv, dim_qk, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     q_shape = [total_q, heads, dim_qk]
     k_shape = [total_kv, head_kv, dim_qk]
     v_shape = [total_kv, head_kv, dim_v]
@@ -245,8 +245,8 @@ def flashattn_bwd_atomic_add(
     k_shape = [total_kv, head_kv, dim_qk]
     v_shape = [total_kv, head_kv, dim_v]
     do_shape = [total_q, heads, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
@@ -256,8 +256,8 @@ def flashattn_bwd_atomic_add(
         dO: T.Tensor(do_shape, dtype),  # type: ignore
         lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
         Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
-        cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
+        cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
         dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
         dK: T.Tensor(k_shape, accum_dtype),  # type: ignore
         dV: T.Tensor(v_shape, accum_dtype),  # type: ignore
@@ -386,8 +386,8 @@ def flashattn_bwd_split(
     do_shape = [total_q, heads, dim_v]
     dk_shape = [groups, total_kv, head_kv, dim_qk]  # sum after kernel
     dv_shape = [groups, total_kv, head_kv, dim_v]  # sum after kernel
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
@@ -397,8 +397,8 @@ def flashattn_bwd_split(
         dO: T.Tensor(do_shape, dtype),  # type: ignore
         lse: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
         Delta: T.Tensor([batch, heads, N_CTX], accum_dtype),  # type: ignore
-        cu_seqlens_q: T.Tensor([batch + 1], "int32"),  # type: ignore
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),  # type: ignore
+        cu_seqlens_q: T.Tensor([batch + 1], T.int32),  # type: ignore
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),  # type: ignore
         dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
         dK: T.Tensor(dk_shape, dtype),  # type: ignore
         dV: T.Tensor(dv_shape, dtype),  # type: ignore
......
@@ -17,8 +17,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -89,8 +89,8 @@ def flashattn_fwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim_v):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim_v]
     blk = 32
@@ -129,8 +129,8 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -70,8 +70,8 @@ def flashattn(batch, heads, seq_len, dim, is_causal, groups=1, block_M=64, block
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.macro
     def MMA0(
......
@@ -45,8 +45,8 @@ def flashattn(
     head_kv = heads // groups
     q_shape = [batch, seq_len, heads, dim]
     kv_shape = [batch, seq_len, head_kv, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.macro
     def MMA0(
......
@@ -65,16 +65,16 @@ def flashattn(batch_size, groups, UQ, UKV, heads, dim, is_causal, block_M=64, bl
     q_shape = [UQ, heads, dim]
     kv_shape = [UKV, head_kv, dim]
     o_shape = [UQ, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def main(
         Q_unpad: T.Tensor(q_shape, dtype),
         K_unpad: T.Tensor(kv_shape, dtype),
         V_unpad: T.Tensor(kv_shape, dtype),
-        cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
-        cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
+        cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
+        cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
         max_seqlen_q: T.int32,
         Output_unpad: T.Tensor(o_shape, dtype),
     ):
......
@@ -15,8 +15,8 @@ import argparse
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -91,8 +91,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, heads, seq_len, dim]
     blk = 32
@@ -131,8 +131,8 @@ def make_dq_layout(dQ):
     },
 )
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, heads, seq_len, dim]
     blk = 64
@@ -160,8 +160,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim) ** 0.5
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -15,8 +15,8 @@ import argparse
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -87,8 +87,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim]
     blk = 32
@@ -127,8 +127,8 @@ def make_dq_layout(dQ):
     },
 )
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim]
     blk = 64
@@ -156,8 +156,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim) ** 0.5
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -15,8 +15,8 @@ import argparse
 def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -88,8 +88,8 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     },
 )
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     shape = [batch, seq_len, heads, dim]
     blk = 32
@@ -125,8 +125,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N):
     sm_scale = (1.0 / dim) ** 0.5
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.prim_func
     def flash_bwd(
......
@@ -24,8 +24,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=6
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......
@@ -24,8 +24,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=128, block_N=
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......
@@ -23,8 +23,8 @@ def get_configs():
 def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.macro
     def MMA0(
......
@@ -23,8 +23,8 @@ def get_configs():
 def flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
     shape = [batch, seq_len, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     @T.macro
     def MMA0(
......