Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
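
The numerical-stability motivation for float32 accumulation can be seen in a few lines of NumPy (a standalone sketch, not TileLang code): once a float16 partial sum grows large enough, each small addend falls below half an ulp and the sum stagnates, while a float32 accumulator keeps absorbing them.

```python
import numpy as np

# 16384 small values; the true sum is about 16384 * 1e-4 ≈ 1.64.
vals = np.full(16384, 1e-4, dtype=np.float16)

# Accumulate in float16: the sum stalls (near 0.25) once the addend
# drops below half an ulp of the running total.
acc16 = np.float16(0.0)
for v in vals:
    acc16 = np.float16(acc16 + v)

# Accumulate in float32: the small addends are still representable
# relative to the running total, so the sum stays close to 1.64.
acc32 = np.float32(0.0)
for v in vals:
    acc32 += np.float32(v)

print(float(acc16), float(acc32))
```

The same effect is why GEMM and attention kernels keep float16 inputs but accumulate in float32.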

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
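
The canonical-name mapping and error-message pattern described above can be sketched as follows. This is an illustrative stand-in, not the actual tilelang dtypes module: the mapping entries, the `make_dtype` helper, and the set of accepted names are all assumptions.

```python
# Hypothetical mapping from loose user input to canonical display strings,
# mimicking the described behavior (e.g. bare "float" means float32).
_CANONICAL_TO_DISPLAY = {
    "float": "float32",
    "half": "float16",
    "int": "int32",
}

# Assumed set of valid display strings, for illustration only.
_KNOWN = {"float16", "float32", "int8", "int16", "int32", "bool"}


def make_dtype(name: str) -> str:
    """Resolve a user-supplied dtype string to its display string."""
    canonical = _CANONICAL_TO_DISPLAY.get(name, name)
    if canonical not in _KNOWN:
        # Refined error message: report what was received and what is accepted.
        raise ValueError(f"invalid dtype {name!r}; expected one of {sorted(_KNOWN)}")
    return canonical
```

With this shape, `make_dtype("float")` resolves to `"float32"`, and an unknown name fails with a message listing the accepted dtypes rather than a bare `KeyError`.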

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.
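
The boundary pattern these commits converge on can be sketched in plain Python: keep the CLI surface as strings, then convert once at the entry point. The `_TORCH_NAME` table and `resolve_dtype` helper below are hypothetical stand-ins for `T.dtype(...)` / `.as_torch()`, written without a torch dependency.

```python
import argparse

# Hypothetical lookup standing in for dtype.as_torch(); values are the
# torch dtype names as strings so the sketch has no torch dependency.
_TORCH_NAME = {"float16": "torch.float16", "bfloat16": "torch.bfloat16"}


def resolve_dtype(name: str) -> str:
    """Convert a CLI dtype string once, at the boundary."""
    if name not in _TORCH_NAME:
        raise ValueError(f"unsupported dtype: {name}")
    return _TORCH_NAME[name]


parser = argparse.ArgumentParser()
# The CLI accepts plain strings; choices keeps bad input out early.
parser.add_argument("--dtype", default="float16", choices=list(_TORCH_NAME))
args = parser.parse_args(["--dtype", "bfloat16"])  # explicit argv for the demo

print(resolve_dtype(args.dtype))
```

Everything downstream of the conversion then handles a single dtype representation instead of re-parsing strings per call site.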

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -137,7 +137,7 @@ import tilelang.language as T
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
 @T.prim_func
 def matmul_relu_kernel(
...
@@ -40,9 +40,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
 shape = [batch, heads, seq_len, dim]
 block_mask_shape = [batch, heads, downsample_len, downsample_len]
-dtype = "float16"
-accum_dtype = "float"
-block_mask_dtype = "bool"
+dtype = T.float16
+accum_dtype = T.float32
+block_mask_dtype = T.bool
 def kernel_func(block_M, block_N, num_stages, threads):
 @T.macro
...
@@ -202,8 +202,8 @@ def chunk_scan_fwd(
 num_stages=2,
 threads=128,
 ):
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 nchunks = T.ceildiv(seqlen, chunk_size)
 p = 1.44269504
...
@@ -62,9 +62,9 @@ def get_configs(args, kwargs):
 M=M,
 N=N,
 K=K,
-in_dtype="float16",
-out_dtype="float16",
-accum_dtype="float",
+in_dtype=T.float16,
+out_dtype=T.float16,
+accum_dtype=T.float32,
 ).with_arch(arch)
 func = carve_template.equivalent_function()
@@ -155,8 +155,8 @@ def matmul(
 # Use half-precision for input data to reduce memory bandwidth,
 # accumulate in float for better numerical accuracy
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 @T.prim_func
 def main(
...
@@ -49,22 +49,22 @@ def tl_matmul(
 enable_rasteration=False,
 ):
 assert in_dtype in [
-"float16",
-"int8",
+T.float16,
+T.int8,
 ], "Currently only float16 and int8 are supported"
 assert out_dtype in [
-"float16",
-"float32",
-"int32",
+T.float16,
+T.float32,
+T.int32,
 ], "Currently only float16, float32 and int32 are supported"
 micro_size_x = micro_size_y = micro_size_k = 16
-if out_dtype == "int32":
+if out_dtype == T.int32:
 micro_size_k = 32
 # This is a debug config
-# chunk = 32 if in_dtype == "float16" else 64
+# chunk = 32 if in_dtype == T.float16 else 64
 shared_scope = "shared.dyn"
 block_M = block_row_warps * warp_row_tiles
@@ -194,9 +194,9 @@ def get_configs(args, kwargs):
 M=M,
 N=N,
 K=K,
-in_dtype="float16",
-out_dtype="float16",
-accum_dtype="float16",
+in_dtype=T.float16,
+out_dtype=T.float16,
+accum_dtype=T.float16,
 ).with_arch(arch)
 func = carve_template.equivalent_function()
@@ -251,9 +251,9 @@ def matmul(
 M,
 N,
 K,
-in_dtype="float16",
-out_dtype="float16",
-accum_dtype="float16",
+in_dtype=T.float16,
+out_dtype=T.float16,
+accum_dtype=T.float16,
 with_roller=False,
 block_row_warps=None,
 block_col_warps=None,
@@ -295,9 +295,9 @@ if __name__ == "__main__":
 args = parser.parse_args()
 M, N, K = args.m, args.n, args.k
-in_dtype = args.dtype
-out_dtype = "float32" if in_dtype == "int8" else "float16"
-accum_dtype = "float32" if in_dtype == "int8" else "float16"
+in_dtype = T.dtype(args.dtype)
+out_dtype = T.float32 if in_dtype == T.int8 else T.float16
+accum_dtype = T.float32 if in_dtype == T.int8 else T.float16
 with_roller = args.with_roller
 with_roller = True
 # Compute total floating-point operations
...
@@ -262,7 +262,7 @@ if __name__ == "__main__":
 total_flops = 2 * M * N * K
 # matmul(...) returns (best_latency, best_config, ref_latency)
-best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
+best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
 best_latency = best_result.latency
 best_config = best_result.config
 A = torch.randn(M, K, dtype=torch.float16, device="cuda")
...
@@ -63,9 +63,9 @@ def get_configs(args, kwargs):
 M=M,
 N=N,
 K=K,
-in_dtype="float16",
-out_dtype="float16",
-accum_dtype="float",
+in_dtype=T.float16,
+out_dtype=T.float16,
+accum_dtype=T.float32,
 ).with_arch(arch)
 func = carve_template.equivalent_function()
@@ -159,8 +159,8 @@ def matmul(
 # Use half-precision for input data to reduce memory bandwidth,
 # accumulate in float for better numerical accuracy
-dtype = "float8_e4m3fnuz" if torch.version.hip is not None else "float8_e4m3"
-accum_dtype = "float"
+dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn
+accum_dtype = T.float32
 @T.prim_func
 def main(
...
@@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles
 ## Elementwise add in TileLang
 ```python
-def elementwise_add(N, threads=256, dtype="bfloat16"):
+def elementwise_add(N, threads=256, dtype=T.bfloat16):
 @T.prim_func
 def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th
 The program can be compiled using the following code:
 ```python
-program = elementwise_add(1024, threads=256, dtype="bfloat16")
+program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
 kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
 ```
 Launching the kernel is straightforward, just call it directly like a function:
@@ -89,7 +89,7 @@ def elementwise_add(
 In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
 ```python
-program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
+program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16)
 kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
 ```
@@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this
 When compiling the example below, let's set `N` to 2047:
 ```python
-def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
+def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
 @T.prim_func
 def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i
 In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
 ```python
-def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
+def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
 @T.prim_func
 def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
@@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki
 But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.
 ```python
-def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
+def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16):
 @T.prim_func
 def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
...
@@ -87,8 +87,8 @@ def fast_flashattn(
 head_kv = heads // groups
 q_shape = [batch, seq_len, heads, dim]
 kv_shape = [batch, seq_len, head_kv, dim]
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 vec_size = qk_coalesced_width
 v_vec_size = v_coalesced_width
@@ -109,7 +109,7 @@ def fast_flashattn(
 num_q_blocks = T.ceildiv(seq_len, block_M)
-bx_loop_var = T.alloc_var("int32")
+bx_loop_var = T.alloc_var(T.int32)
 bx_loop_var = b_split
 with T.While(bx_loop_var < num_q_blocks):
@@ -236,8 +236,8 @@ def get_bwd_configs():
 @tilelang.jit(out_idx=[2])
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 shape = [batch, seq_len, heads, dim]
 blk = 32
@@ -280,8 +280,8 @@ def flashattn_bwd(
 head_kv = heads // groups
 q_shape = [batch, seq_len, heads, dim]
 kv_shape = [batch, seq_len, head_kv, dim]
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 @T.prim_func
 def flash_bwd_kernel(
@@ -368,8 +368,8 @@ def flashattn_bwd(
 @tilelang.jit(out_idx=[1])
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 shape = [batch, seq_len, heads, dim]
 blk = 64
...
@@ -100,8 +100,8 @@ def fast_flashattn(
 head_kv = heads // groups
 q_shape = [batch, seq_len, heads, dim]
 kv_shape = [batch, seq_len, head_kv, dim]
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 vec_size = qk_coalesced_width
 v_vec_size = v_coalesced_width
@@ -121,7 +121,7 @@ def fast_flashattn(
 num_q_blocks = T.ceildiv(seq_len, block_M)
-bx = T.alloc_var("int32")
+bx = T.alloc_var(T.int32)
 bx = b_split
 with T.While(bx < num_q_blocks):
...
@@ -21,9 +21,9 @@ M = N = K = 1024
 def kernel(block_M=128, block_N=128, block_K=32, num_stages=3, thread_num=128):
 @T.prim_func
-def main(A: T.Tensor((M, K), "float16"),
-B: T.Tensor((N, K), "float16"),
-C: T.Tensor((M, N), "float")):
+def main(A: T.Tensor((M, K), T.float16),
+B: T.Tensor((N, K), T.float16),
+C: T.Tensor((M, N), T.float)):
 # ... (kernel definition)
 return main
@@ -40,9 +40,9 @@ from tilelang.carver.arch import CUDA
 def kernel(N=64, C=256, H=512, W=512, F=512, K=3, block_M=64, block_N=128):
 @T.prim_func
-def main(data: T.Tensor((N, H, W, C), "float16"),
-kernel: T.Tensor((K, K, C, F), "float16"),
-out: T.Tensor((N, (H-K+1), (W-K+1), F), "float")):
+def main(data: T.Tensor((N, H, W, C), T.float16),
+kernel: T.Tensor((K, K, C, F), T.float16),
+out: T.Tensor((N, (H-K+1), (W-K+1), F), T.float)):
 # ... (convolution kernel definition)
 return main
...
@@ -25,12 +25,12 @@ def check_hopper():
 return False
-def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
+def kernel(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
 KH, KW = K, K
 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 is_hopper = check_hopper()
 @T.prim_func
...
@@ -15,8 +15,8 @@ def kernel(
 thread_num=None,
 enable_rasteration=None,
 ):
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
 @T.prim_func
 def matmul(
...
 import torch
 import argparse
 from tilelang.profiler import do_bench
+from tilelang import language as T
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
@@ -135,7 +136,8 @@ def main(
 dtype: str = "float16",
 tune: bool = False,
 ):
-torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+dtype = T.dtype(dtype)
+torch_dtype = dtype.as_torch()
 if window_size is not None:
 print("Using sliding window attention.")
 assert window_size <= seq_q
...
 import torch
 import argparse
 from tilelang.profiler import do_bench
+from tilelang import language as T
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
@@ -131,7 +132,8 @@ def main(
 dtype: str = "float16",
 tune: bool = False,
 ):
-torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+dtype = T.dtype(dtype)
+torch_dtype = dtype.as_torch()
 if window_size is not None:
 print("Using sliding window attention.")
 assert window_size <= seq_q
...
@@ -37,7 +37,7 @@ def flashattn_fwd(
 block_N=64,
 num_stages=1,
 threads=128,
-dtype: str = "float16",
+dtype: T.dtype = T.float16,
 ):
 if window_size is not None:
 assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -49,7 +49,7 @@ def flashattn_fwd(
 head_kv = heads // groups
 q_shape = [batch, heads, seq_len, dim]
 kv_shape = [batch, head_kv, seq_len, dim]
-accum_dtype = "float"
+accum_dtype = T.float32
 @T.prim_func
 def flash_fwd(
@@ -140,8 +140,8 @@ def flashattn_fwd(
 tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 },
 )
-def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
-accum_dtype = "float"
+def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
+accum_dtype = T.float32
 shape = [batch, heads, seq_len, dim]
 blk = 32
@@ -179,8 +179,8 @@ def make_dq_layout(dQ):
 tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 },
 )
-def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
-accum_dtype = "float"
+def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
+accum_dtype = T.float32
 shape = [batch, heads, seq_len, dim]
 blk = 64
@@ -204,7 +204,7 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
 tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
 }
 )
-def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype="float16"):  # None for full attention
+def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale=None, dtype=T.float16):  # None for full attention
 if sm_scale is None:
 sm_scale = (1.0 / dim) ** 0.5
 scale = sm_scale * 1.44269504  # log2(e)
@@ -212,7 +212,7 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale
 head_kv = heads // groups
 q_shape = [batch, heads, seq_len, dim]
 kv_shape = [batch, head_kv, seq_len, dim]
-accum_dtype = "float"
+accum_dtype = T.float32
 block_M, block_N, num_stages, threads = get_bwd_configs()
@@ -309,8 +309,8 @@ def flashattn_bwd(batch, heads, seq_len, dim, groups, window_size=None, sm_scale
 @tilelang.jit(out_idx=-1)
-def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: str = "float16"):
-accum_dtype = "float"
+def flashattn_bwd_dsink(batch, heads, seq_len, block=256, dtype: T.dtype = T.float16):
+accum_dtype = T.float32
 shape = [batch, heads, seq_len]
 @T.prim_func
@@ -346,7 +346,7 @@ class _attention(torch.autograd.Function):
 q, k, v, sinks = [maybe_contiguous(x) for x in (q, k, v, sinks)]
 BATCH, H, N_CTX, D_HEAD = q.shape
-dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
+dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
 kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype)
 o, lse = kernel(q, k, v, sinks)
 ctx.save_for_backward(q, k, v, sinks, o, lse)
@@ -359,7 +359,7 @@ class _attention(torch.autograd.Function):
 q, k, v, sinks, o, lse = ctx.saved_tensors
 BATCH, H, N_CTX, D_HEAD = q.shape
 groups = ctx.groups
-dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
+dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
 kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
 kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
@@ -440,7 +440,8 @@ def main(
 window_size: Optional[int] = None,
 dtype: str = "float16",
 ):
-torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+dtype = T.dtype(dtype)
+torch_dtype = dtype.as_torch()
 if window_size is not None:
 print("Using sliding window attention.")
assert window_size <= N_CTX assert window_size <= N_CTX
...@@ -472,8 +473,8 @@ def main( ...@@ -472,8 +473,8 @@ def main(
# Checks # Checks
rtol, atol = { rtol, atol = {
"float16": (1e-2, 1e-2), T.float16: (1e-2, 1e-2),
"bfloat16": (2e-2, 2e-2), T.bfloat16: (2e-2, 2e-2),
}[dtype] }[dtype]
assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}" assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
......
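A recurring change in these hunks is replacing the ambiguous `accum_dtype = "float"` with the explicit `T.float32`, in line with the commit's goal of consistent accumulator precision. The value of accumulating in a wider type than the operands can be seen with a small stdlib-only demo (using float32-vs-float64 as a stand-in for float16-vs-float32 accumulation, since Python has no native half precision; this sketch is illustrative and not tied to the kernels above):

```python
import struct

def round_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

v = round_f32(0.1)  # 0.1 is not exactly representable; ~0.10000000149

# Narrow accumulator: round to float32 after every addition,
# mimicking accumulation in the operand precision.
acc_narrow = 0.0
for _ in range(10):
    acc_narrow = round_f32(acc_narrow + v)

# Wide accumulator: keep full float64 precision throughout,
# mimicking a float32 accumulator over float16 operands.
acc_wide = 0.0
for _ in range(10):
    acc_wide += v

assert acc_narrow != 1.0                                    # per-step rounding error accumulates
assert abs(acc_wide - 1.0) < abs(acc_narrow - 1.0)          # wide accumulation stays closer
```

The per-step rounding in the narrow accumulator is exactly the error source a `float32` accumulator avoids when summing long chains of `float16` products in attention and GEMM kernels.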
@@ -41,7 +41,7 @@ def flashattn(
     block_N=128,
     num_stages=2,
     threads=256,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -53,7 +53,7 @@ def flashattn(
     head_kv = heads // groups
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, head_kv, seq_kv, dim]
-    accum_dtype = "float"
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@@ -263,10 +263,11 @@ def main(
     dim: int = 128,
     groups: int = 8,
     window_size: Optional[int] = None,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     tune: bool = False,
 ):
-    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+    dtype = T.dtype(dtype)
+    torch_dtype = dtype.as_torch()
     if window_size is not None:
         print("Using sliding window attention.")
         assert window_size <= seq_q
...
@@ -36,7 +36,7 @@ def flashattn_fwd(
     block_N=64,
     num_stages=1,
     threads=128,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -46,7 +46,7 @@ def flashattn_fwd(
     scale = sm_scale * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
-    accum_dtype = "float"
+    accum_dtype = T.float32
     @T.prim_func
     def flash_fwd(
@@ -137,8 +137,8 @@ def flashattn_fwd(
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
 )
-def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
+def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
-    accum_dtype = "float"
+    accum_dtype = T.float32
     shape = [batch, heads, seq_len, dim]
     blk = 32
@@ -176,8 +176,8 @@ def make_dq_layout(dQ):
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     },
 )
-def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
+def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: T.dtype = T.float16):
-    accum_dtype = "float"
+    accum_dtype = T.float32
     shape = [batch, heads, seq_len, dim]
     blk = 64
@@ -208,7 +208,7 @@ def flashattn_bwd(
     dim,
     window_size=None,  # None for full attention
     sm_scale=None,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ):
     block_M, block_N, num_stages, threads = get_bwd_configs()
@@ -217,7 +217,7 @@ def flashattn_bwd(
     scale = sm_scale * 1.44269504  # log2(e)
     shape = [batch, heads, seq_len, dim]
-    accum_dtype = "float"
+    accum_dtype = T.float32
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -315,8 +315,8 @@ def flashattn_bwd(
 @tilelang.jit(out_idx=-1)
-def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: str = "float16"):
+def flashattn_bwd_dsink(batch, heads, seq_len, block=128, dtype: T.dtype = T.float16):
-    accum_dtype = "float"
+    accum_dtype = T.float32
     shape = [batch, heads, seq_len]
     @T.prim_func
@@ -346,7 +346,7 @@ class _attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, sinks, window_size):
         BATCH, H, N_CTX, D_HEAD = q.shape
-        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
+        dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
         kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype)
         o, lse = kernel(q, k, v, sinks)
         ctx.save_for_backward(q, k, v, sinks, o, lse)
@@ -364,7 +364,7 @@ class _attention(torch.autograd.Function):
             return x
         do, q, k, v, sinks, o = [maybe_contiguous(x) for x in (do, q, k, v, sinks, o)]
-        dtype = "float16" if q.dtype == torch.float16 else "bfloat16"
+        dtype = T.float16 if q.dtype == torch.float16 else T.bfloat16
         kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
         kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype)
         delta = kernel_prep(o, do)
@@ -433,8 +433,9 @@ def ref_program(
     return output.transpose(1, 2).contiguous()
-def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: str = "float16"):
+def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window_size: Optional[int] = None, dtype: T.dtype = T.float16):
-    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+    dtype = T.dtype(dtype)
+    torch_dtype = dtype.as_torch()
     if window_size is not None:
         print("Using sliding window attention.")
         assert window_size <= N_CTX
@@ -466,8 +467,8 @@ def main(BATCH: int = 1, H: int = 1, N_CTX: int = 512, D_HEAD: int = 128, window
     # Checks
     rtol, atol = {
-        "float16": (1e-2, 1e-2),
-        "bfloat16": (2e-2, 2e-2),
+        T.float16: (1e-2, 1e-2),
+        T.bfloat16: (2e-2, 2e-2),
     }[dtype]
     assert torch.allclose(O, O_ref, rtol=rtol, atol=atol), f"O max err: {(O - O_ref).abs().max()}"
     assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}"
...
@@ -35,7 +35,7 @@ def flashattn(
     block_N=64,
     num_stages=1,
     threads=128,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -45,7 +45,7 @@ def flashattn(
     scale = sm_scale * 1.44269504  # log2(e)
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
-    accum_dtype = "float"
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@@ -246,10 +246,11 @@ def main(
     seq_kv: int = 256,
     dim: int = 128,
     window_size: Optional[int] = None,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     tune: bool = False,
 ):
-    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+    dtype = T.dtype(dtype)
+    torch_dtype = dtype.as_torch()
     if window_size is not None:
         print("Using sliding window attention.")
         assert window_size <= seq_q
@@ -308,7 +309,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
     parser.add_argument("--dim", type=int, default=128, help="dim")
     parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
-    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
     parser.add_argument("--tune", action="store_true", help="tune")
     args = parser.parse_args()
     main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
...
@@ -36,7 +36,7 @@ def flashattn(
     block_N=128,
     num_stages=2,
     threads=256,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ):
     if window_size is not None:
         assert window_size % block_N == 0, "window_size must be divisible by block_N"
@@ -47,7 +47,7 @@ def flashattn(
     q_shape = [batch, heads, seq_q, dim]
     kv_shape = [batch, heads, seq_kv, dim]
-    accum_dtype = "float"
+    accum_dtype = T.float32
     past_len = seq_kv - seq_q
     assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@@ -256,10 +256,11 @@ def main(
     seq_kv: int = 256,
     dim: int = 128,
     window_size: Optional[int] = None,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     tune: bool = False,
 ):
-    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+    dtype = T.dtype(dtype)
+    torch_dtype = dtype.as_torch()
     if window_size is not None:
         print("Using sliding window attention.")
         assert window_size <= seq_q
@@ -315,7 +316,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_kv", type=int, default=4096, help="sequence length of key/value")
    parser.add_argument("--dim", type=int, default=128, help="dim")
     parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)")
-    parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16")
+    parser.add_argument("--dtype", type=str, default=T.float16, help="dtype, can be float16 or bfloat16")
     parser.add_argument("--tune", action="store_true", help="tune")
     args = parser.parse_args()
     main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.window_size, args.dtype, args.tune)
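The other recurring change in this commit is the dtype normalization at the top of each `main`: the hand-written `{"float16": torch.float16, ...}[dtype]` lookup is replaced by wrapping the argument with `T.dtype(...)` and deriving the torch dtype via `.as_torch()`. A minimal self-contained sketch of that pattern follows; the `DType` class and `_TORCH_NAMES` table here are hypothetical stand-ins for tilelang's `T.dtype`, not its actual implementation, and `as_torch()` returns a name string so the sketch runs without torch installed:

```python
# Hypothetical sketch of the dtype-normalization pattern used in the diff:
# accept either a dtype name string or an existing dtype object, normalize
# once, then derive the torch-side dtype from it. Value-based __eq__/__hash__
# are what let the rtol/atol dict use dtype objects as keys.

_TORCH_NAMES = {"float16": "torch.float16", "bfloat16": "torch.bfloat16", "float32": "torch.float32"}

class DType:
    """Stand-in for T.dtype: DType("float16") and DType(DType("float16")) are equivalent."""

    def __init__(self, name):
        self.name = name.name if isinstance(name, DType) else str(name)

    def as_torch(self):
        # The real T.dtype returns e.g. torch.float16 here.
        return _TORCH_NAMES[self.name]

    def __eq__(self, other):
        return isinstance(other, DType) and self.name == other.name

    def __hash__(self):
        return hash(self.name)

float16 = DType("float16")
bfloat16 = DType("bfloat16")

def main(dtype=float16):
    dtype = DType(dtype)            # idempotent: strings and DType both accepted
    torch_dtype = dtype.as_torch()
    rtol, atol = {float16: (1e-2, 1e-2), bfloat16: (2e-2, 2e-2)}[dtype]
    return torch_dtype, rtol, atol

assert main("float16") == ("torch.float16", 1e-2, 1e-2)
assert main(bfloat16) == ("torch.bfloat16", 2e-2, 2e-2)
```

The idempotent constructor is what lets the examples keep accepting plain `--dtype float16` strings on the command line while the kernels work with typed dtype objects internally.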