Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
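
The numerical-stability motivation for float32 accumulation can be seen in a small standalone sketch (pure Python, emulating float16 rounding with the struct module; no tilelang APIs involved):

```python
import struct

def to_f16(x: float) -> float:
    """Round a Python float to IEEE float16 precision (emulates fp16 storage)."""
    return struct.unpack('e', struct.pack('e', x))[0]

# Accumulate 4096 ones. A float16 running sum stalls at 2048, because the
# fp16 ulp at 2048 is 2, so 2048 + 1 rounds back to 2048 (ties-to-even).
acc_f16 = 0.0
acc_wide = 0.0  # stands in for a float32 accumulator
for _ in range(4096):
    acc_f16 = to_f16(acc_f16 + 1.0)
    acc_wide += 1.0

print(acc_f16)   # 2048.0 (stuck)
print(acc_wide)  # 4096.0 (correct)
```

This is why the examples keep float16 inputs but accumulate in T.float32.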

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.
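
A minimal stand-in for the extended dtypes module might look like the sketch below. The DType class, registry, and dtype() helper are hypothetical illustrations, not tilelang's actual classes; only the presence of "int4" and "int16" comes from the commit.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DType:
    """Hypothetical stand-in for tilelang's dtype objects (illustrative only)."""
    name: str
    bits: int

# Register the new sub-byte and short integer types alongside the usual ones.
_REGISTRY = {
    d.name: d
    for d in (DType("int4", 4), DType("int8", 8), DType("int16", 16),
              DType("int32", 32), DType("float16", 16), DType("float32", 32))
}

def dtype(value):
    """Accept either a DType or its string name, mirroring how T.dtype(...)
    lets the string and object spellings coexist across the codebase."""
    if isinstance(value, DType):
        return value
    return _REGISTRY[value]
```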

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
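
The canonicalization of string inputs described above could be sketched as a lookup table. The table contents and the direction of the mapping are assumptions; only "float" resolving to float32 and the float8_e4m3 rename are suggested by the surrounding commits.

```python
# Assumed mapping from loose/legacy spellings to one canonical display string.
_CANONICAL = {
    "float": "float32",
    "half": "float16",
    "int": "int32",
    "float8_e4m3": "float8_e4m3fn",
}

def canonical_name(name: str) -> str:
    """Return the canonical display string for a dtype spelling;
    already-canonical names pass through unchanged."""
    return _CANONICAL.get(name, name)
```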

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Fixed a typo in the data type conversion call, replacing dtype..as_torch() with dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
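
The corrected conversion can be sketched as below. The as_torch() method name comes from the commit; the wrapper class and its table are illustrative, and the sketch returns torch attribute names as strings so it runs without torch installed (real code would use getattr(torch, name)).

```python
class DTypeView:
    """Minimal sketch of a dtype wrapper exposing .as_torch() (illustrative)."""

    _TORCH_ATTR = {
        "float16": "float16",
        "bfloat16": "bfloat16",
        "float32": "float32",
        "int8": "int8",
        "int32": "int32",
    }

    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Real code would resolve this to a torch.dtype object; we return the
        # qualified attribute name so the sketch has no torch dependency.
        return "torch." + self._TORCH_ATTR[self.name]
```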

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
......@@ -137,7 +137,7 @@ import tilelang.language as T
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
@T.prim_func
def matmul_relu_kernel(
......
......@@ -40,9 +40,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "bool"
dtype = T.float16
accum_dtype = T.float32
block_mask_dtype = T.bool
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
......
......@@ -202,8 +202,8 @@ def chunk_scan_fwd(
num_stages=2,
threads=128,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
nchunks = T.ceildiv(seqlen, chunk_size)
p = 1.44269504
......
......@@ -62,9 +62,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -155,8 +155,8 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
......
......@@ -49,22 +49,22 @@ def tl_matmul(
enable_rasteration=False,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
# chunk = 32 if in_dtype == "float16" else 64
# chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
block_M = block_row_warps * warp_row_tiles
......@@ -194,9 +194,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -251,9 +251,9 @@ def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
with_roller=False,
block_row_warps=None,
block_col_warps=None,
......@@ -295,9 +295,9 @@ if __name__ == "__main__":
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
in_dtype = args.dtype
out_dtype = "float32" if in_dtype == "int8" else "float16"
accum_dtype = "float32" if in_dtype == "int8" else "float16"
in_dtype = T.dtype(args.dtype)
out_dtype = T.float32 if in_dtype == T.int8 else T.float16
accum_dtype = T.float32 if in_dtype == T.int8 else T.float16
with_roller = args.with_roller
with_roller = True
# Compute total floating-point operations
......
......@@ -262,7 +262,7 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
......
......@@ -63,9 +63,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -159,8 +159,8 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
accum_dtype = "float"
dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn
accum_dtype = T.float32
@T.prim_func
def main(
......
......@@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles
## Elementwise add in TileLang
```python
def elementwise_add(N, threads=256, dtype="bfloat16"):
def elementwise_add(N, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th
The program can be compiled using the following code:
```python
program = elementwise_add(1024, threads=256, dtype="bfloat16")
program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
Launching the kernel is straightforward, just call it directly like a function:
......@@ -89,7 +89,7 @@ def elementwise_add(
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
```python
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
......@@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this
When compiling the example below, let's set `N` to 2047:
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki
But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.
```python
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......
......@@ -87,8 +87,8 @@ def fast_flashattn(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
......@@ -109,7 +109,7 @@ def fast_flashattn(
num_q_blocks = T.ceildiv(seq_len, block_M)
bx_loop_var = T.alloc_var("int32")
bx_loop_var = T.alloc_var(T.int32)
bx_loop_var = b_split
with T.While(bx_loop_var < num_q_blocks):
......@@ -236,8 +236,8 @@ def get_bwd_configs():
@tilelang.jit(out_idx=[2])
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 32
......@@ -280,8 +280,8 @@ def flashattn_bwd(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def flash_bwd_kernel(
......@@ -368,8 +368,8 @@ def flashattn_bwd(
@tilelang.jit(out_idx=[1])
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
shape = [batch, seq_len, heads, dim]
blk = 64
......
......@@ -100,8 +100,8 @@ def fast_flashattn(
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
vec_size = qk_coalesced_width
v_vec_size = v_coalesced_width
......@@ -121,7 +121,7 @@ def fast_flashattn(
num_q_blocks = T.ceildiv(seq_len, block_M)
bx = T.alloc_var("int32")
bx = T.alloc_var(T.int32)
bx = b_split
with T.While(bx < num_q_blocks):
......
......@@ -21,9 +21,9 @@ M = N = K = 1024
def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
@T.prim_func
def main(A: T.Tensor((M, K), "float16"),
B: T.Tensor((N, K), "float16"),
C: T.Tensor((M, N), "float")):
def main(A: T.Tensor((M, K), T.float16),
B: T.Tensor((N, K), T.float16),
C: T.Tensor((M, N), T.float)):
# ... (kernel definition)
return main
......@@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA
def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
@T.prim_func
def main(data: T.Tensor((N, H, W, C), "float16"),
kernel: T.Tensor((K, K, C, F), "float16"),
out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")):
def main(data: T.Tensor((N, H, W, C), T.float16),
kernel: T.Tensor((K, K, C, F), T.float16),
out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)):
# ... (convolution kernel definition)
return main
......
......@@ -25,12 +25,12 @@ def check_hopper():
return False
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
......
......@@ -15,8 +15,8 @@ def kernel(
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def matmul(
......
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
......@@ -135,7 +136,8 @@ def main(
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
......
import torch
import argparse
from tilelang.profiler import do_bench
from tilelang import language as T
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
......@@ -131,7 +132,8 @@ def main(
dtype: str = "float16",
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
......
......@@ -37,7 +37,7 @@ def flashattn_fwd(
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -49,7 +49,7 @@ def flashattn_fwd(
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = "float"
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
......@@ -140,8 +140,8 @@ def flashattn_fwd(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 32
......@@ -179,8 +179,8 @@ def make_dq_layout(dQ):
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 64
......@@ -204,7 +204,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"): # None for full attention
def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16): # None for full attention
if sm_scale is None:
sm_scale = (1.0 / dim) ** 0.5
scale = sm_scale * 1.44269504 # log2(e)
......@@ -212,7 +212,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale
head_kv = heads // groups
q_shape = [batch, heads, seq_len, dim]
kv_shape = [batch, head_kv, seq_len, dim]
accum_dtype = "float"
accum_dtype = T.float32
block_M, block_N, num_stages, threads = get_bwd_configs()
......@@ -309,8 +309,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len]
@T.prim_func
......@@ -346,7 +346,7 @@ class _attention(torch.autograd.Function):
q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
......@@ -359,7 +359,7 @@ class _attention(torch.autograd.Function):
q, k, v, sinks, o, lse = ctx.saved_tensors
BATCH, H, N_CTX, D_HEAD = q.shape
groups = ctx.groups
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
......@@ -440,7 +440,8 @@ def main(
window_size: Optional[int] = None,
dtype: str = "float16",
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= N_CTX
......@@ -472,8 +473,8 @@ def main(
# Checks
rtol, atol = {
"float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2),
T.float16: (1e-2, 1e-2),
T.bfloat16: (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
......
......@@ -41,7 +41,7 @@ def flashattn(
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -53,7 +53,7 @@ def flashattn(
head_kv = heads // groups
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, head_kv, seq_kv, dim]
accum_dtype = "float"
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......@@ -263,10 +263,11 @@ def main(
dim: int = 128,
groups: int = 8,
window_size: Optional[int] = None,
dtype: str = "float16",
dtype: T.dtype = T.float16,
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
......
......@@ -36,7 +36,7 @@ def flashattn_fwd(
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -46,7 +46,7 @@ def flashattn_fwd(
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = "float"
accum_dtype = T.float32
@T.prim_func
def flash_fwd(
......@@ -137,8 +137,8 @@ def flashattn_fwd(
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 32
......@@ -176,8 +176,8 @@ def make_dq_layout(dQ):
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len, dim]
blk = 64
......@@ -208,7 +208,7 @@ def flashattn_bwd(
dim,
window_size=None, # None for full attention
sm_scale=None,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
block_M, block_N, num_stages, threads = get_bwd_configs()
......@@ -217,7 +217,7 @@ def flashattn_bwd(
scale = sm_scale * 1.44269504 # log2(e)
shape = [batch, heads, seq_len, dim]
accum_dtype = "float"
accum_dtype = T.float32
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -315,8 +315,8 @@ def flashattn_bwd(
@tilelang.jit(out_idx=-1)
def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"):
accum_dtype = "float"
def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16):
accum_dtype = T.float32
shape = [batch, heads, seq_len]
@T.prim_func
......@@ -346,7 +346,7 @@ class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sinks, window_size):
BATCH, H, N_CTX, D_HEAD = q.shape
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
o, lse = kernel(q, k, v, sinks)
ctx.save_for_backward(q, k, v, sinks, o, lse)
......@@ -364,7 +364,7 @@ class _attention(torch.autograd.Function):
return x
do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
delta = kernel_prep(o, do)
......@@ -433,8 +433,9 @@ def ref_program(
return output.transpose(1, 2).contiguous()
def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16):
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= N_CTX
......@@ -466,8 +467,8 @@ def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window
# Checks
rtol, atol = {
"float16": (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2),
T.float16: (1e-2, 1e-2),
T.bfloat16: (2e-2, 2e-2),
}[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
......
......@@ -35,7 +35,7 @@ def flashattn(
block_N=64,
num_stages=1,
threads=128,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -45,7 +45,7 @@ def flashattn(
scale = sm_scale * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = "float"
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......@@ -246,10 +246,11 @@ def main(
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
dtype: T.dtype = T.float16,
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
......@@ -308,7 +309,7 @@ if __name__ == "__main__":
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
......@@ -36,7 +36,7 @@ def flashattn(
block_N=128,
num_stages=2,
threads=256,
dtype: str = "float16",
dtype: T.dtype = T.float16,
):
if window_size is not None:
assert window_size % block_N == 0, "window_size must be divisible by block_N"
......@@ -47,7 +47,7 @@ def flashattn(
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
accum_dtype = "float"
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......@@ -256,10 +256,11 @@ def main(
seq_kv: int = 256,
dim: int = 128,
window_size: Optional[int] = None,
dtype: str = "float16",
dtype: T.dtype = T.float16,
tune: bool = False,
):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
dtype = T.dtype(dtype)
torch_dtype = dtype.as_torch()
if window_size is not None:
print("Using sliding window attention.")
assert window_size <= seq_q
......@@ -315,7 +316,7 @@ if __name__ == "__main__":
parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
parser.add_argument("--tune", action="store_true", help="tune")
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)