Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling and functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules, including the flash attention, GEMM analysis, convolution, and DeepSeek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.
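The two spellings above (the string `"float16"` and the object `T.float16`) can coexist without churn if each dtype object is itself a string. The sketch below illustrates that design; the `_DType`/`T` names here are hypothetical stand-ins, not tilelang's actual implementation.

```python
# Minimal sketch, assuming a tilelang-style dtype token that subclasses str,
# so T.float16 and "float16" compare equal and both styles stay interchangeable.

class _DType(str):
    """A dtype token that is also its own canonical string (hypothetical)."""
    __slots__ = ()


class T:  # hypothetical stand-in for the tilelang `T` namespace
    float16 = _DType("float16")
    int8 = _DType("int8")
    int32 = _DType("int32")


# Existing string-based call sites keep working unchanged:
assert T.float16 == "float16"
assert T.float16 in ["float16", "int8"]   # membership checks unchanged
assert isinstance(T.int32, str)           # still usable anywhere a str is expected
```

With this design, refactors that swap one spelling for the other (as several of the commits below do) are purely cosmetic.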

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
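The canonical-name mapping described above might look like the following sketch. The table entries, the `make_dtype` name, and the error wording are assumptions for illustration, not the actual tilelang code.

```python
# Hedged sketch: map user-supplied dtype strings to one canonical display
# string, so legacy aliases (e.g. "float") resolve to a single spelling.

_CANONICAL = {
    "float": "float32",   # legacy alias kept for backward compatibility
    "float32": "float32",
    "float16": "float16",
    "int8": "int8",
    "int32": "int32",
    "bool": "bool",
}


def make_dtype(name: str) -> str:
    """Resolve a dtype string to its canonical form, with a clear error."""
    try:
        return _CANONICAL[name]
    except KeyError:
        raise ValueError(f"unknown dtype string: {name!r}") from None


assert make_dtype("float") == "float32"
```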

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
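The fixed spelling `dtype.as_torch()` (the doubled dot was a syntax error) presumably returns the matching `torch.dtype`. A self-contained sketch of that kind of conversion, with a hypothetical `DType` class and name table standing in for the real tilelang object:

```python
# Hypothetical sketch of an as_torch()-style conversion. The real method
# presumably returns an actual torch.dtype; here we map canonical dtype
# strings to the corresponding torch attribute names so the sketch runs
# without torch installed.

_TORCH_NAMES = {
    "float16": "float16",
    "float32": "float32",
    "int8": "int8",
    "int32": "int32",
}


class DType:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # A caller could then do getattr(torch, dtype.as_torch())
        return _TORCH_NAMES[self.name]


# The corrected call site: one dot, not two.
assert DType("float16").as_torch() == "float16"
```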

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.
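The "additional float8 variants" check described above can be centralized so conversion logic does not enumerate every spelling at each call site. A sketch under assumptions (the variant set and helper name are illustrative, not tilelang's actual codegen code):

```python
# Hypothetical sketch of a centralized float8-variant check for codegen.
FLOAT8_VARIANTS = frozenset({
    "float8_e4m3",
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
})


def is_float8(dtype: str) -> bool:
    """True for any supported float8 variant string."""
    return dtype in FLOAT8_VARIANTS


assert is_float8("float8_e4m3fn")
assert not is_float8("float16")
```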

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -76,13 +76,13 @@ def bitnet_158_int8xint2_decode(
         reduce_thread=32,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     storage_nbit = 8
     num_bits = 2
@@ -94,7 +94,7 @@ def bitnet_158_int8xint2_decode(
     MAX_TRANSACTION_SIZE_IN_BITS = 128
     micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
     micro_size_k_compressed = micro_size_k // num_elems_per_byte
-    storage_dtype = "int8"
+    storage_dtype = T.int8
     block_K = reduce_thread * micro_size_k
     use_dp4a = True
@@ -194,12 +194,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
 # interleave weight numpy implementation
-def interleave_weight(qweight, nbits=4, target_dtype="float16"):
-    assert target_dtype in ["float16", "int8"]
+def interleave_weight(qweight, nbits=4, target_dtype=T.float16):
+    assert target_dtype in [T.float16, T.int8]
     # reinterpret the data type of qweight to int32
     qweight = qweight.view(np.int32)
     new_qweight = np.zeros_like(qweight)
-    bits_stride = 8 if target_dtype == "int8" else 16
+    bits_stride = 8 if target_dtype == T.int8 else 16
     mask = (1 << nbits) - 1  # for 4bit the val is 0x0000000f
     num_groups = 32 // bits_stride
     elems_per_group = bits_stride // nbits
@@ -209,7 +209,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
             shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
             new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
-    if nbits == 1 and target_dtype == "int8":
+    if nbits == 1 and target_dtype == T.int8:
         # special handling for 1b interleave
         n16_weight = new_qweight & np.int32(0xF0F00F0F)
         n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
@@ -217,12 +217,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
         n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
         n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
         return n16_weight.view(np.int8)
-    elif nbits == 2 and target_dtype == "float16":
+    elif nbits == 2 and target_dtype == T.float16:
         n8_weight = new_qweight & np.int32(0xFF0000FF)
         n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
         n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
         return n8_weight.view(np.int8)
-    elif nbits == 1 and target_dtype == "float16":
+    elif nbits == 1 and target_dtype == T.float16:
         n8_weight = new_qweight & 0xF000000F
         n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
         n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
@@ -259,4 +259,4 @@ def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype,
 if __name__ == "__main__":
-    assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32")
+    assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32)
@@ -88,9 +88,9 @@ def bitnet_158_int8xint2_prefill(
     Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low-precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
     The returned prim_func expects:
-    - A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
+    - A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8).
     - B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
-    - C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
+    - C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32).
     Details:
     - Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
@@ -98,15 +98,15 @@ def bitnet_158_int8xint2_prefill(
     - block_row_warps, block_col_warps: number of warps per block in row/col.
     - warp_row_tiles, warp_col_tiles: tiles per warp.
     - chunk: K-sized chunk per block (block_K).
-    - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
+    - micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32).
     - Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
     - Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
     Parameters:
         M, N, K (int): Global matrix dimensions.
-        in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
-        out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
-        accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
+        in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8.
+        out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32.
+        accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32).
         fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
         block_row_warps (int): Warps in block row dimension.
         block_col_warps (int): Warps in block column dimension.
@@ -118,18 +118,18 @@ def bitnet_158_int8xint2_prefill(
         T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
     """
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
-    if accum_dtype == "int32":
+    if accum_dtype == T.int32:
         micro_size_k = 32
     num_elems_per_byte = 4
@@ -138,7 +138,7 @@ def bitnet_158_int8xint2_prefill(
     local_size_compressed = local_size // num_elems_per_byte
     shared_scope = "shared.dyn"
-    storage_dtype = "int8"
+    storage_dtype = T.int8
     # Pipeline Stage
     stage = 2
@@ -317,12 +317,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
 # interleave weight numpy implementation
-def interleave_weight(qweight, nbits=4, target_dtype="float16"):
-    assert target_dtype in ["float16", "int8"]
+def interleave_weight(qweight, nbits=4, target_dtype=T.float16):
+    assert target_dtype in [T.float16, T.int8]
     # reinterpret the data type of qweight to int32
     qweight = qweight.view(np.int32)
     new_qweight = np.zeros_like(qweight)
-    bits_stride = 8 if target_dtype == "int8" else 16
+    bits_stride = 8 if target_dtype == T.int8 else 16
     mask = (1 << nbits) - 1  # for 4bit the val is 0x0000000f
     num_groups = 32 // bits_stride
     elems_per_group = bits_stride // nbits
@@ -332,7 +332,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
             shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
             new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
-    if nbits == 1 and target_dtype == "int8":
+    if nbits == 1 and target_dtype == T.int8:
         # special handling for 1b interleave
         n16_weight = new_qweight & np.int32(0xF0F00F0F)
         n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
@@ -340,12 +340,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
         n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
         n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
         return n16_weight.view(np.int8)
-    elif nbits == 2 and target_dtype == "float16":
+    elif nbits == 2 and target_dtype == T.float16:
         n8_weight = new_qweight & np.int32(0xFF0000FF)
         n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
         n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
         return n8_weight.view(np.int8)
-    elif nbits == 1 and target_dtype == "float16":
+    elif nbits == 1 and target_dtype == T.float16:
         n8_weight = new_qweight & 0xF000000F
         n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
         n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
@@ -382,4 +382,4 @@ def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype
 if __name__ == "__main__":
-    assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32")
+    assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32)
@@ -38,18 +38,18 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
-    if out_dtype == "int32":
+    if out_dtype == T.int32:
         micro_size_k = 32
     # This is a debug config
@@ -57,7 +57,7 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 64
     warp_col_tiles = 64
-    chunk = 32 if in_dtype == "float16" else 64
+    chunk = 32 if in_dtype == T.float16 else 64
     shared_scope = "shared.dyn"
     # Pipeline Stage
@@ -183,7 +183,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None
     print(src_code)
-    if in_dtype == "int8":
+    if in_dtype == T.int8:
         A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8)
         B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8)
     else:
@@ -209,12 +209,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
-    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
+    assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
+    assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
 if __name__ == "__main__":
     # bitblas.testing.main()
-    # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
-    # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
-    assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")
+    # assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
+    # assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
+    assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32)
@@ -41,9 +41,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
     shape = [batch, heads, seq_len, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
-    dtype = "float16"
-    accum_dtype = "float"
-    block_mask_dtype = "bool"
+    dtype = T.float16
+    accum_dtype = T.float32
+    block_mask_dtype = T.bool
     def kernel_func(block_M, block_N, num_stages, threads):
         @T.macro
...
@@ -14,8 +14,8 @@ from heuristic import num_splits_heuristic
 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // heads_kv
     @tilelang.jit(
@@ -43,9 +43,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_indices: T.Tensor(shape_indices, "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
-        block_table: T.Tensor(shape_block_table, "int32"),
+        block_indices: T.Tensor(shape_indices, T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
+        block_table: T.Tensor(shape_block_table, T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
     ):
@@ -139,7 +139,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         lse_logsum_local = T.alloc_local([1], accum_dtype)
         lse_max_local = T.alloc_local([1], accum_dtype)
         scale_local = T.alloc_local([1], accum_dtype)
-        max_split = T.alloc_local([1], "int32")
+        max_split = T.alloc_local([1], T.int32)
         T.annotate_layout(
             {
@@ -177,9 +177,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_indices: T.Tensor(shape_indices, "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
-        block_table: T.Tensor(shape_block_table, "int32"),
+        block_indices: T.Tensor(shape_indices, T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
+        block_table: T.Tensor(shape_block_table, T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
         Output: T.Tensor(shape_o, dtype),
...
@@ -11,8 +11,8 @@ from heuristic import num_splits_heuristic
 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // heads_kv
     @tilelang.jit(
@@ -35,9 +35,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_indices: T.Tensor(shape_indices, "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
-        # actual_num_blocks: T.Tensor([batch], "int32"),
+        block_indices: T.Tensor(shape_indices, T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
+        # actual_num_blocks: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
     ):
@@ -128,7 +128,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         lse_logsum_local = T.alloc_local([1], accum_dtype)
         lse_max_local = T.alloc_local([1], accum_dtype)
         scale_local = T.alloc_local([1], accum_dtype)
-        max_split = T.alloc_local([1], "int32")
+        max_split = T.alloc_local([1], T.int32)
         T.annotate_layout(
             {
@@ -166,9 +166,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_indices: T.Tensor(shape_indices, "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
-        # actual_num_blocks: T.Tensor([batch], "int32"),
+        block_indices: T.Tensor(shape_indices, T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
+        # actual_num_blocks: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
         Output: T.Tensor(shape_o, dtype),
...
@@ -13,8 +13,8 @@ from heuristic import num_splits_heuristic
 def flashattn(batch, heads, heads_kv, dim, dim_v):
     scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // heads_kv
     @tilelang.jit(
@@ -37,8 +37,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_mask: T.Tensor(shape_mask, "bool"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_mask: T.Tensor(shape_mask, T.bool),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
     ):
@@ -156,8 +156,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        block_mask: T.Tensor(shape_mask, "bool"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_mask: T.Tensor(shape_mask, T.bool),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, heads, num_split], accum_dtype),
         Output_partial: T.Tensor(part_shape, accum_dtype),
         Output: T.Tensor(shape_o, dtype),
...
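The `1.44269504 # log2(e)` factor in these kernels folds log2(e) into the softmax scale so the inner loop can use the cheaper base-2 exponential (`exp2`) instead of `exp`. A minimal pure-Python sketch of that identity (the helper names here are illustrative, not from the kernels):

```python
import math

LOG2E = 1.44269504  # log2(e), the constant folded into `scale` above


def softmax_exp(scores, scale):
    # reference formulation: exp(scale * s)
    return [math.exp(scale * s) for s in scores]


def softmax_exp2(scores, scale):
    # kernel-style formulation: pre-multiply scale by log2(e), then use 2**x
    scale2 = scale * LOG2E
    return [2.0 ** (scale2 * s) for s in scores]


dim = 64
scale = (1.0 / dim) ** 0.5
scores = [0.5, -1.25, 3.0]
a = softmax_exp(scores, scale)
b = softmax_exp2(scores, scale)
assert all(abs(x - y) < 1e-6 for x, y in zip(a, b))
```

Since `exp(x) == 2**(x * log2(e))`, scaling once up front makes the two formulations agree to floating-point precision.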
@@ -93,7 +93,7 @@ def supply_program(params: List[KernelParam]):
 )
 @tilelang.jit(out_idx=[-1])
 def blocksparse_matmul(
-    M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
+    M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K)
...
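For intuition, `block_mask_shape = (M // block_M, N // block_N, K // block_K)` holds one boolean per (row-block, column-block, K-block) tile product: a tile's contribution is skipped when its mask entry is false. A tiny pure-Python reference of that semantics (illustrative only, not the GPU kernel):

```python
def blocksparse_matmul_ref(A, B, mask, block_M, block_N, block_K):
    """C[i][j] = sum_k A[i][k] * B[k][j], restricted to tiles whose
    (i-block, j-block, k-block) entry in `mask` is True."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for bi in range(M // block_M):
        for bj in range(N // block_N):
            for bk in range(K // block_K):
                if not mask[bi][bj][bk]:
                    continue  # pruned tile: skip its partial product
                for i in range(bi * block_M, (bi + 1) * block_M):
                    for j in range(bj * block_N, (bj + 1) * block_N):
                        C[i][j] += sum(
                            A[i][k] * B[k][j]
                            for k in range(bk * block_K, (bk + 1) * block_K)
                        )
    return C


# a fully-True mask reproduces the dense matmul
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
mask = [[[True, True], [True, True]], [[True, True], [True, True]]]
assert blocksparse_matmul_ref(A, B, mask, 1, 1, 1) == [[19.0, 22.0], [43.0, 50.0]]
```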
@@ -5,8 +5,8 @@ from typing import Tuple
 from tilelang.utils.tensor import torch_assert_close
 # support bfloat16, float, float16
-dtype = "bfloat16"
-accum_dtype = "float"
+dtype = T.bfloat16
+accum_dtype = T.float32
 @tilelang.jit(out_idx=[2, 3])
@@ -18,8 +18,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     @T.prim_func
     def group_per_split_token_cast(
         X: T.Tensor((M, N), dtype),
-        batch_sizes: T.Tensor((BG,), "int32"),
-        X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
+        batch_sizes: T.Tensor((BG,), T.int32),
+        X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn),
         X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
     ):
         with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
@@ -30,8 +30,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
             y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
             y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
             y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
-            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
-            row_offset = T.alloc_fragment((1,), "int32")
+            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
+            row_offset = T.alloc_fragment((1,), T.int32)
             T.annotate_layout(
                 {
@@ -163,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tenso
 def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
     if batch_sizes is None:
         batch_sizes = [2048, 6144]
-    if dtype == "float":
+    if dtype == T.float:
         x = torch.randn(M, N, device="cuda", dtype=torch.float32)
-    elif dtype == "float16":
+    elif dtype == T.float16:
         x = torch.randn(M, N, device="cuda", dtype=torch.float16)
-    elif dtype == "bfloat16":
+    elif dtype == T.bfloat16:
         x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
...
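The cast kernels above scale each 128-element group so that its absolute maximum maps onto the float8_e4m3fn finite range (±448), recording the scale for dequantization. A pure-Python sketch of that per-group scaling (the function name and `eps` guard are illustrative; no real fp8 cast happens here):

```python
FP8_E4M3_MAX = 448.0  # finite max of float8_e4m3fn


def quantize_group(values, eps=1e-4):
    """Per-group scaling: divide by scale = amax / 448 so the group's
    absolute max lands at the fp8 range boundary; return the scale
    needed to dequantize. Values stay as Python floats (sketch only)."""
    amax = max(max(abs(v) for v in values), eps)
    scale = amax / FP8_E4M3_MAX
    q = [max(min(v / scale, FP8_E4M3_MAX), -FP8_E4M3_MAX) for v in values]
    return q, scale


q, s = quantize_group([0.5, -2.0, 1.0])
# the largest-magnitude element hits the fp8 boundary...
assert abs(max(abs(v) for v in q) - FP8_E4M3_MAX) < 1e-9
# ...and multiplying back by the scale recovers the inputs
assert all(abs(v * s - orig) < 1e-9 for v, orig in zip(q, [0.5, -2.0, 1.0]))
```

The `eps` floor mirrors the usual guard against all-zero groups producing a zero scale; the real kernels also round to actual fp8 storage, which this sketch omits.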
@@ -7,14 +7,14 @@ from tilelang.utils.tensor import torch_assert_close
 @tilelang.jit(out_idx=[1, 2])
 def per_token_cast_to_fp8(M, N, blk_m):
-    dtype = "float"
+    dtype = T.float
     group_size = 128
     fp8_min = -448.0
     fp8_max = 448.0
     @T.prim_func
     def per_token_cast(
-        X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
+        X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
     ):
         with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
             row = bx
@@ -23,7 +23,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
             y_amax_local = T.alloc_fragment((blk_m,), dtype)
             y_s_local = T.alloc_fragment((blk_m,), dtype)
             y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
-            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
+            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
             T.annotate_layout(
                 {
...
import tilelang
import tilelang.language as T


# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32

func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])

import torch

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)
print(c)

ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
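The tiling in the example above can be mirrored in plain Python to see the loop structure: `bx`/`by` walk the grid of output tiles and `ko` accumulates over K blocks. A small CPU-only sketch of that decomposition (illustrative, not the generated CUDA code):

```python
def ceildiv(a, b):
    return (a + b - 1) // b


def tiled_matmul_ref(A, B, block_M, block_N, block_K):
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # grid of output tiles, as in `T.Kernel(ceildiv(N, block_N), ceildiv(M, block_M))`
    for by in range(ceildiv(M, block_M)):
        for bx in range(ceildiv(N, block_N)):
            # the `for ko in T.Pipelined(...)` loop: accumulate over K blocks
            for ko in range(ceildiv(K, block_K)):
                for i in range(by * block_M, min((by + 1) * block_M, M)):
                    for j in range(bx * block_N, min((bx + 1) * block_N, N)):
                        for k in range(ko * block_K, min((ko + 1) * block_K, K)):
                            C[i][j] += A[i][k] * B[k][j]
    return C


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
assert tiled_matmul_ref(A, B, 2, 2, 2) == [[58.0, 64.0], [139.0, 154.0]]
```

The `min(...)` bounds handle ragged edge tiles, which the real kernel covers with its copy/clear semantics.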
@@ -25,12 +25,12 @@ def ref_program(stride, padding, dilation):
 @tilelang.jit(out_idx=[2])
-def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
+def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
     KH, KW = K, K
     OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
     OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     is_hopper = check_hopper()
     @T.prim_func
...
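The output-size formula used in both convolution examples (`OH = (H + 2*P - D*(K-1) - 1) // S + 1`) is the standard dilated-convolution arithmetic: the effective kernel extent is `D*(K-1) + 1`. A quick sanity check of the formula in plain Python:

```python
def conv_out_size(H, K, S, P, D):
    """Output spatial size as computed in the convolution examples:
    (H + 2*P - D*(K-1) - 1) // S + 1."""
    return (H + 2 * P - D * (K - 1) - 1) // S + 1


# "same"-style 3x3 conv: stride 1, padding 1, dilation 1 preserves the size
assert conv_out_size(56, 3, 1, 1, 1) == 56
# stride 2 halves it
assert conv_out_size(56, 3, 2, 1, 1) == 28
# dilation 2 enlarges the effective 3x3 kernel to cover 5 positions
assert conv_out_size(56, 3, 1, 0, 2) == 52
```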
@@ -75,13 +75,13 @@ def get_heuristic_config() -> dict:
 @tilelang.autotune(configs=get_configs())
 @tilelang.jit(out_idx=[2])
 def convolution(
-    N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
+    N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
 ):
     KH, KW = K, K
     OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
     OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     is_hopper = check_hopper()
     @T.prim_func
...
@@ -20,11 +20,11 @@ def tl_gemm(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float8_e4m3",
+        T.float8_e4m3fn,
     ], "Currently only float8_e4m3 is supported"
     assert out_dtype in [
-        "bfloat16",
-        "float32",
+        T.bfloat16,
+        T.float32,
     ], "Currently only float16 and float32 are supported"
     group_size = 128
@@ -44,14 +44,14 @@ def tl_gemm(
         A: T.Tensor(A_shape, in_dtype),
         B: T.Tensor(B_shape, in_dtype),
         C: T.Tensor((M, N), out_dtype),
-        scales_a: T.Tensor(Scales_A_shape, "float32"),
-        scales_b: T.Tensor(Scales_B_shape, "float32"),
+        scales_a: T.Tensor(Scales_A_shape, T.float32),
+        scales_b: T.Tensor(Scales_B_shape, T.float32),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype)
-            Scale_C_shared = T.alloc_shared((block_M), "float32")
+            Scale_C_shared = T.alloc_shared((block_M), T.float32)
             C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
             C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
@@ -176,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
 def main():
-    assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")
+    assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32)
 if __name__ == "__main__":
-    for dtype in ["float8_e4m3"]:
-        for out_dtype in ["bfloat16", "float32"]:
+    for dtype in [T.float8_e4m3fn]:
+        for out_dtype in [T.bfloat16, T.float32]:
             for block_N in [16, 32, 64, 128]:
-                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
+                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32)
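The `scales_a`/`scales_b` tensors in `tl_gemm` dequantize the fp8 product back to the accumulation type: conceptually `C = (A_q * s_a) @ (B_q * s_b)`. A deliberately simplified pure-Python sketch with one scalar scale per matrix (the real kernel applies per-128-element-group scales; names here are illustrative):

```python
def scaled_gemm_ref(A_q, B_q, scale_a, scale_b):
    """Dequantized GEMM: accumulate the quantized product, then
    rescale. One scalar scale per matrix keeps the sketch short."""
    M, K = len(A_q), len(A_q[0])
    N = len(B_q[0])
    return [
        [sum(A_q[i][k] * B_q[k][j] for k in range(K)) * scale_a * scale_b
         for j in range(N)]
        for i in range(M)
    ]


# 2*1 + 4*3 = 14, rescaled by 0.5 * 0.25 = 0.125 -> 1.75
assert scaled_gemm_ref([[2.0, 4.0]], [[1.0], [3.0]], 0.5, 0.25) == [[1.75]]
```

With per-group scales, the rescaling instead happens once per K-group before accumulating into `C_local_accum`, which is why the kernel keeps both `C_local` and `C_local_accum` fragments.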
@@ -36,8 +36,8 @@ def get_configs():
 )
 def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...
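These decode kernels split the KV sequence into `num_split` chunks, each producing a partial output plus its log-sum-exp (the `glse`/`Output_partial` buffers seen elsewhere in this commit), and a combine pass then merges them with a max-shifted rescale. A pure-Python sketch of that reduction (illustrative, not the kernel code):

```python
import math


def combine_splits(partials, lses):
    """Merge per-split attention outputs using their log-sum-exps:
    out = sum_s exp(lse_s - lse_max) * partial_s / sum_s exp(lse_s - lse_max).
    Subtracting lse_max keeps the exponentials numerically safe."""
    lse_max = max(lses)
    weights = [math.exp(l - lse_max) for l in lses]
    denom = sum(weights)
    dim = len(partials[0])
    return [
        sum(w * p[d] for w, p in zip(weights, partials)) / denom
        for d in range(dim)
    ]


# two splits with equal lse contribute equally: the result is their average
out = combine_splits([[1.0, 2.0], [3.0, 4.0]], [0.7, 0.7])
assert out == [2.0, 3.0]
```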
@@ -15,8 +15,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
     scale = float(softmax_scale * 1.44269504) # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...
@@ -17,8 +17,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
     if softmax_scale is None:
         softmax_scale = (dv + dpe) ** -0.5
     scale = float(softmax_scale * 1.44269504) # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = h_q // h_kv
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert h_kv == 1, "h_kv must be 1"
@@ -30,8 +30,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        CACHE_SEQLENS: T.Tensor([batch], "int32"),
+        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        CACHE_SEQLENS: T.Tensor([batch], T.int32),
         Output: T.Tensor([batch, h_q, dv], dtype),
     ):
         with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
@@ -103,8 +103,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        CACHE_SEQLENS: T.Tensor([batch], "int32"),
+        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        CACHE_SEQLENS: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
     ):
@@ -224,8 +224,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
         Output: T.Tensor([batch, h_q, dv], dtype),
@@ -239,8 +239,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
         Output: T.Tensor([batch, h_q, dv], dtype),
...
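The `kv_group_num = h_q // h_kv` lines in these examples implement grouped-query attention head sharing: consecutive groups of `kv_group_num` query heads attend against a single KV head. A sketch of the mapping (the helper name is illustrative):

```python
def kv_head_for_query_head(q_head, h_q, h_kv):
    """With kv_group_num = h_q // h_kv, consecutive blocks of
    kv_group_num query heads share one KV head."""
    kv_group_num = h_q // h_kv
    return q_head // kv_group_num


# 8 query heads, 2 KV heads -> groups of 4
assert [kv_head_for_query_head(q, 8, 2) for q in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
# h_kv == 1 (as asserted in the MLA examples): every query head maps to head 0
assert all(kv_head_for_query_head(q, 16, 1) == 0 for q in range(16))
```

This grouping is also why the kernels clamp the per-block head tile with `VALID_BLOCK_H = min(block_H, kv_group_num)`: a tile of query heads must not straddle two KV heads.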
@@ -16,8 +16,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...
@@ -27,8 +27,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
     sm_scale = float(softmax_scale * 1.44269504) # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...
@@ -15,9 +15,9 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
-    dtype = "float16"
-    q_dtype = "float8_e4m3"
-    accum_dtype = "float"
+    dtype = T.float16
+    q_dtype = T.float8_e4m3fn
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
...