Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed the accumulation data type from the shorthand string "float" to the explicit T.float32 in multiple example scripts. The underlying type is unchanged (the TVM dtype string "float" already denotes float32), so this is a consistency and readability fix.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.
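Sub-byte types such as int4 cannot occupy a byte on their own, so they are usually described by their bit width and packed into int8 storage. A minimal sketch of such a dtype record, assuming a hypothetical DType class (this is illustrative, not tilelang's actual dtypes module):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DType:
    """Illustrative dtype record: a name plus a bit width."""
    name: str
    bits: int


int4 = DType("int4", 4)
int8 = DType("int8", 8)
int16 = DType("int16", 16)


def elems_per_byte(dtype: DType) -> int:
    """How many elements of a byte-or-narrower dtype pack into one int8 byte."""
    assert 8 % dtype.bits == 0, "only dtypes of at most 8 bits pack into a byte"
    return 8 // dtype.bits
```

For example, `elems_per_byte(int4)` is 2, matching the 2-bit packing factor of 4 used in the bitnet examples below.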

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
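The canonical-name-to-display-string mapping described above can be sketched as follows; the table entries and helper name are assumptions for illustration, since TVM-style dtype strings allow shorthand spellings such as "float" for float32:

```python
# Illustrative mapping between canonical dtype names and their shorthand
# display strings; not the actual tilelang tables.
CANONICAL_TO_DISPLAY = {"float32": "float", "int32": "int"}
DISPLAY_TO_CANONICAL = {v: k for k, v in CANONICAL_TO_DISPLAY.items()}
KNOWN = {"float16", "bfloat16", "float32", "int8", "int32", "bool"}


def make_dtype(name: str) -> str:
    """Accept either spelling and return the canonical dtype name."""
    name = DISPLAY_TO_CANONICAL.get(name, name)
    if name not in KNOWN:
        raise ValueError(f"unknown dtype string: {name!r}")
    return name
```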

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the data type conversion call from the mistyped dtype..as_torch() (note the doubled dot, a syntax error) to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
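A hedged sketch of what an as_torch() accessor looks like, to show the call site the typo fix restores. The real tilelang method returns torch.dtype objects; this version maps to name strings so it runs without PyTorch installed, and the table contents are assumptions:

```python
# Hypothetical name table; the real implementation would return torch dtypes.
_TORCH_NAMES = {
    "float16": "torch.float16",
    "bfloat16": "torch.bfloat16",
    "float32": "torch.float32",
    "int8": "torch.int8",
    "int32": "torch.int32",
}


class DType:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        """Return the name of the corresponding torch dtype."""
        return _TORCH_NAMES[self.name]
```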

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
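PyTorch spells its 4-exponent-bit, 3-mantissa-bit float8 type torch.float8_e4m3fn, where the "fn" suffix marks the finite-only variant (no infinities, a single NaN encoding), so aligning tilelang's spelling with it avoids a lossy name translation. A sketch of the rename, with an illustrative map:

```python
# Illustrative rename map from the old spelling to the torch-aligned one.
FLOAT8_RENAMES = {"float8_e4m3": "float8_e4m3fn"}


def modernize_float8(name: str) -> str:
    """Return the torch-aligned float8 name; other names pass through."""
    return FLOAT8_RENAMES.get(name, name)
```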

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.
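Grouping the float8 variants behind a single predicate keeps the conversion conditions uniform. The sketch below is illustrative only: the actual check lives in the C++ CUDA codegen, and the set of variant names is an assumption modeled on the float8 types PyTorch exposes:

```python
# Assumed variant list, modeled on PyTorch's float8 dtypes.
FLOAT8_VARIANTS = {
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
}


def is_float8(dtype: str) -> bool:
    """True if the dtype name is one of the recognized float8 variants."""
    return dtype in FLOAT8_VARIANTS
```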

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -76,13 +76,13 @@ def bitnet_158_int8xint2_decode(
reduce_thread=32,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
storage_nbit = 8
num_bits = 2
@@ -94,7 +94,7 @@ def bitnet_158_int8xint2_decode(
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
micro_size_k_compressed = micro_size_k // num_elems_per_byte
storage_dtype = "int8"
storage_dtype = T.int8
block_K = reduce_thread * micro_size_k
use_dp4a = True
@@ -194,12 +194,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
# interleave weight numpy implementation
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
assert target_dtype in ["float16", "int8"]
def interleave_weight(qweight, nbits=4, target_dtype=T.float16):
assert target_dtype in [T.float16, T.int8]
# reinterpret the data type of qweight to int32
qweight = qweight.view(np.int32)
new_qweight = np.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
bits_stride = 8 if target_dtype == T.int8 else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
@@ -209,7 +209,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
if nbits == 1 and target_dtype == T.int8:
# special handling for 1b interleave
n16_weight = new_qweight & np.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
@@ -217,12 +217,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(np.int8)
elif nbits == 2 and target_dtype == "float16":
elif nbits == 2 and target_dtype == T.float16:
n8_weight = new_qweight & np.int32(0xFF0000FF)
n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(np.int8)
elif nbits == 1 and target_dtype == "float16":
elif nbits == 1 and target_dtype == T.float16:
n8_weight = new_qweight & 0xF000000F
n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
@@ -259,4 +259,4 @@ def assert_bitnet_158_int8xint2_decode_correctness(M, N, K, in_dtype, out_dtype,
if __name__ == "__main__":
assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, "int8", "int32", "int32")
assert_bitnet_158_int8xint2_decode_correctness(1, 256, 256, T.int8, T.int32, T.int32)
@@ -88,9 +88,9 @@ def bitnet_158_int8xint2_prefill(
Create a TVM GPU prim_func implementing a block-tiled matrix multiply that multiplies dense A by compressed/interleaved low‑precision B (2-bit packed into int8 storage), decoding B to int8 on-chip and accumulating into C.
The returned prim_func expects:
- A: shape (M, K) with dtype `in_dtype` ("float16" or "int8").
- A: shape (M, K) with dtype `in_dtype` (T.float16 or T.int8).
- B: compressed storage with shape (N, K/4) and int8 storage layout (packing 4 2-bit elements per byte).
- C: output buffer shape (M, N) with dtype `out_dtype` ("float16", "float32", or "int32").
- C: output buffer shape (M, N) with dtype `out_dtype` (T.float16, T.float32, or T.int32).
Details:
- Builds a tiled, pipelined kernel using shared memory and warp-level MMA intrinsics (INT4TensorCoreIntrinEmitter). B is loaded from compressed storage, decoded to int8 in threads (via decode_i2u_to_i8s / decode_i2s_to_i8s), and dequantized into a shared buffer used by the MMA emitter.
@@ -98,15 +98,15 @@
- block_row_warps, block_col_warps: number of warps per block in row/col.
- warp_row_tiles, warp_col_tiles: tiles per warp.
- chunk: K-sized chunk per block (block_K).
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == "int32").
- micro sizes are fixed (16x16x16, except micro_k=32 when accum_dtype == T.int32).
- Uses 2-stage pipelining by default to overlap loads and compute and applies a swizzle layout to improve L2 behavior.
- Assertions: raises AssertionError if in_dtype or out_dtype are not among supported values.
Parameters:
M, N, K (int): Global matrix dimensions.
in_dtype (str): Input and decoded B element dtype; "float16" or "int8".
out_dtype (str): Output C dtype; one of "float16", "float32", "int32".
accum_dtype (str): Accumulator dtype used by MMA (e.g., "int32").
in_dtype (str): Input and decoded B element dtype; T.float16 or T.int8.
out_dtype (str): Output C dtype; one of T.float16, T.float32, T.int32.
accum_dtype (str): Accumulator dtype used by MMA (e.g., T.int32).
fast_decoding (bool): If True, enable the fast decoding path (affects which device decode is used).
block_row_warps (int): Warps in block row dimension.
block_col_warps (int): Warps in block column dimension.
@@ -118,18 +118,18 @@
T.prim_func: A TVM prim_func implementing the described GPU kernel suitable for compilation and execution.
"""
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if accum_dtype == "int32":
if accum_dtype == T.int32:
micro_size_k = 32
num_elems_per_byte = 4
@@ -138,7 +138,7 @@
local_size_compressed = local_size // num_elems_per_byte
shared_scope = "shared.dyn"
storage_dtype = "int8"
storage_dtype = T.int8
# Pipeline Stage
stage = 2
@@ -317,12 +317,12 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8):
# interleave weight numpy implementation
def interleave_weight(qweight, nbits=4, target_dtype="float16"):
assert target_dtype in ["float16", "int8"]
def interleave_weight(qweight, nbits=4, target_dtype=T.float16):
assert target_dtype in [T.float16, T.int8]
# reinterpret the data type of qweight to int32
qweight = qweight.view(np.int32)
new_qweight = np.zeros_like(qweight)
bits_stride = 8 if target_dtype == "int8" else 16
bits_stride = 8 if target_dtype == T.int8 else 16
mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f
num_groups = 32 // bits_stride
elems_per_group = bits_stride // nbits
@@ -332,7 +332,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
if nbits == 1 and target_dtype == "int8":
if nbits == 1 and target_dtype == T.int8:
# special handling for 1b interleave
n16_weight = new_qweight & np.int32(0xF0F00F0F)
n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16
@@ -340,12 +340,12 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4
n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12
return n16_weight.view(np.int8)
elif nbits == 2 and target_dtype == "float16":
elif nbits == 2 and target_dtype == T.float16:
n8_weight = new_qweight & np.int32(0xFF0000FF)
n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16
n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8
return n8_weight.view(np.int8)
elif nbits == 1 and target_dtype == "float16":
elif nbits == 1 and target_dtype == T.float16:
n8_weight = new_qweight & 0xF000000F
n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8
n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16
@@ -382,4 +382,4 @@ def assert_bitnet_158_int8xint2_prefill_correctness(M, N, K, in_dtype, out_dtype
if __name__ == "__main__":
assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, "int8", "int32", "int32")
assert_bitnet_158_int8xint2_prefill_correctness(256, 256, 256, T.int8, T.int32, T.int32)
@@ -38,18 +38,18 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
@@ -57,7 +57,7 @@
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
@@ -183,7 +183,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None
print(src_code)
if in_dtype == "int8":
if in_dtype == T.int8:
A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8)
else:
@@ -209,12 +209,12 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
if __name__ == "__main__":
# bitblas.testing.main()
# assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
# assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")
# assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
# assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
assert_tl_matmul_correctness(16384, 16384, 16384, T.int8, T.int32, T.int32)
@@ -41,9 +41,9 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "float16"
accum_dtype = "float"
block_mask_dtype = "bool"
dtype = T.float16
accum_dtype = T.float32
block_mask_dtype = T.bool
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
@@ -14,8 +14,8 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
@@ -43,9 +43,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
block_table: T.Tensor(shape_block_table, T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
@@ -139,7 +139,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
max_split = T.alloc_local([1], T.int32)
T.annotate_layout(
{
@@ -177,9 +177,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
block_table: T.Tensor(shape_block_table, "int32"),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
block_table: T.Tensor(shape_block_table, T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
@@ -11,8 +11,8 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
@@ -35,9 +35,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
@@ -128,7 +128,7 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
max_split = T.alloc_local([1], "int32")
max_split = T.alloc_local([1], T.int32)
T.annotate_layout(
{
@@ -166,9 +166,9 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_indices: T.Tensor(shape_indices, "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
# actual_num_blocks: T.Tensor([batch], "int32"),
block_indices: T.Tensor(shape_indices, T.int32),
cache_seqlens: T.Tensor([batch], T.int32),
# actual_num_blocks: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
@@ -13,8 +13,8 @@ from heuristic import num_splits_heuristic
def flashattn(batch, heads, heads_kv, dim, dim_v):
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
kv_group_num = heads // heads_kv
@tilelang.jit(
@@ -37,8 +37,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
block_mask: T.Tensor(shape_mask, T.bool),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
):
@@ -156,8 +156,8 @@ def flashattn(batch, heads, heads_kv, dim, dim_v):
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_mask: T.Tensor(shape_mask, "bool"),
cache_seqlens: T.Tensor([batch], "int32"),
block_mask: T.Tensor(shape_mask, T.bool),
cache_seqlens: T.Tensor([batch], T.int32),
glse: T.Tensor([batch, heads, num_split], accum_dtype),
Output_partial: T.Tensor(part_shape, accum_dtype),
Output: T.Tensor(shape_o, dtype),
@@ -93,7 +93,7 @@ def supply_program(params: List[KernelParam]):
)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(
M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@@ -5,8 +5,8 @@ from typing import Tuple
from tilelang.utils.tensor import torch_assert_close
# support bfloat16, float, float16
dtype = "bfloat16"
accum_dtype = "float"
dtype = T.bfloat16
accum_dtype = T.float32
@tilelang.jit(out_idx=[2, 3])
@@ -18,8 +18,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
@T.prim_func
def group_per_split_token_cast(
X: T.Tensor((M, N), dtype),
batch_sizes: T.Tensor((BG,), "int32"),
X_fp8: T.Tensor((BG, M_max, N), "float8_e4m3"),
batch_sizes: T.Tensor((BG,), T.int32),
X_fp8: T.Tensor((BG, M_max, N), T.float8_e4m3fn),
X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)), accum_dtype),
):
with T.Kernel(T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
@@ -30,8 +30,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_fragment((1,), "int32")
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
row_offset = T.alloc_fragment((1,), T.int32)
T.annotate_layout(
{
@@ -163,11 +163,11 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> Tuple[torch.Tenso
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
if batch_sizes is None:
batch_sizes = [2048, 6144]
if dtype == "float":
if dtype == T.float:
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
elif dtype == "float16":
elif dtype == T.float16:
x = torch.randn(M, N, device="cuda", dtype=torch.float16)
elif dtype == "bfloat16":
elif dtype == T.bfloat16:
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
@@ -7,14 +7,14 @@ from tilelang.utils.tensor import torch_assert_close
@tilelang.jit(out_idx=[1, 2])
def per_token_cast_to_fp8(M, N, blk_m):
dtype = "float"
dtype = T.float
group_size = 128
fp8_min = -448.0
fp8_max = 448.0
@T.prim_func
def per_token_cast(
X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "float8_e4m3"), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), T.float8_e4m3fn), X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
row = bx
@@ -23,7 +23,7 @@ def per_token_cast_to_fp8(M, N, blk_m):
y_amax_local = T.alloc_fragment((blk_m,), dtype)
y_s_local = T.alloc_fragment((blk_m,), dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), T.float8_e4m3fn)
T.annotate_layout(
{
import tilelang
import tilelang.language as T
# @tilelang.jit(compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
M = 1024
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
func = matmul(M, N, K, block_M, block_N, block_K)
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags="-O3 --use_fast_math --expt-relaxed-constexpr")
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"])
# or jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", compile_flags=["-O3 --use_fast_math --expt-relaxed-constexpr"])
import torch
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)
print(c)
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
@@ -25,12 +25,12 @@ def ref_program(stride, padding, dilation):
@tilelang.jit(out_idx=[2])
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
@@ -75,13 +75,13 @@ def get_heuristic_config() -> dict:
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[2])
def convolution(
N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"
N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
is_hopper = check_hopper()
@T.prim_func
@@ -20,11 +20,11 @@ def tl_gemm(
accum_dtype,
):
assert in_dtype in [
"float8_e4m3",
T.float8_e4m3fn,
], "Currently only float8_e4m3 is supported"
assert out_dtype in [
"bfloat16",
"float32",
T.bfloat16,
T.float32,
], "Currently only float16 and float32 are supported"
group_size = 128
@@ -44,14 +44,14 @@
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
scales_a: T.Tensor(Scales_A_shape, "float32"),
scales_b: T.Tensor(Scales_B_shape, "float32"),
scales_a: T.Tensor(Scales_A_shape, T.float32),
scales_b: T.Tensor(Scales_B_shape, T.float32),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)
Scale_C_shared = T.alloc_shared((block_M), "float32")
Scale_C_shared = T.alloc_shared((block_M), T.float32)
C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)
@@ -176,11 +176,11 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
 def main():
-    assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")
+    assert_tl_gemm_correctness(1024, 1024, 8192, 128, T.float8_e4m3fn, T.bfloat16, T.float32)

 if __name__ == "__main__":
-    for dtype in ["float8_e4m3"]:
-        for out_dtype in ["bfloat16", "float32"]:
+    for dtype in [T.float8_e4m3fn]:
+        for out_dtype in [T.bfloat16, T.float32]:
             for block_N in [16, 32, 64, 128]:
-                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")
+                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, T.float32)
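The `scales_a`/`scales_b` tensors in the kernel above carry group-wise FP8 dequantization factors: each `group_size` slice of the K dimension has its own scale, and every partial product is rescaled before accumulation in `accum_dtype`. A rough NumPy reference with simplified, assumed shapes (per-row scales for A, per-column scales for B; not the kernel's exact tiled layout):

```python
import numpy as np

def groupwise_gemm(A_q, B_q, scales_a, scales_b, group_size):
    """Dequantized GEMM: rescale each K-group's partial product.

    Assumed shapes: scales_a is (M, K // group_size),
    scales_b is (K // group_size, N).
    """
    M, K = A_q.shape
    N = B_q.shape[1]
    C = np.zeros((M, N), dtype=np.float32)  # accumulate in float32
    for g in range(K // group_size):
        ks = slice(g * group_size, (g + 1) * group_size)
        partial = A_q[:, ks].astype(np.float32) @ B_q[ks, :].astype(np.float32)
        C += partial * scales_a[:, g:g + 1] * scales_b[g:g + 1, :]
    return C
```

With all scales equal to 1 this reduces to a plain matmul, which is a convenient correctness baseline.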
@@ -36,8 +36,8 @@ def get_configs():
 )
 def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
......
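Each of these attention kernels folds `1.44269504` (log2 of e) into the softmax scale. The reason is that GPUs expose a fast `exp2` instruction, and `exp(x) == 2 ** (x * log2(e))`, so pre-multiplying the scale once lets the inner loop use `exp2` instead of `exp`. A pure-Python illustration of the identity (helper names are ours, not from the examples):

```python
import math

LOG2_E = 1.44269504  # same constant the kernels fold into their scale

def softmax_exp(scores, scale):
    """Reference scaled softmax using exp()."""
    m = max(scores)                  # subtract the max for stability
    e = [math.exp((s - m) * scale) for s in scores]
    t = sum(e)
    return [v / t for v in e]

def softmax_exp2(scores, scale):
    """Same result via exp2: exp(x) == 2 ** (x * log2(e))."""
    scale2 = scale * LOG2_E          # folded once, outside the hot loop
    m = max(scores)
    e = [2.0 ** ((s - m) * scale2) for s in scores]
    t = sum(e)
    return [v / t for v in e]
```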
@@ -15,8 +15,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
     scale = float(softmax_scale * 1.44269504)  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
......
@@ -17,8 +17,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
     if softmax_scale is None:
         softmax_scale = (dv + dpe) ** -0.5
     scale = float(softmax_scale * 1.44269504)  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = h_q // h_kv
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert h_kv == 1, "h_kv must be 1"
@@ -30,8 +30,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        CACHE_SEQLENS: T.Tensor([batch], "int32"),
+        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        CACHE_SEQLENS: T.Tensor([batch], T.int32),
         Output: T.Tensor([batch, h_q, dv], dtype),
     ):
         with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
@@ -103,8 +103,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        CACHE_SEQLENS: T.Tensor([batch], "int32"),
+        BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        CACHE_SEQLENS: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
     ):
@@ -224,8 +224,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
         Output: T.Tensor([batch, h_q, dv], dtype),
@@ -239,8 +239,8 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
         Q_pe: T.Tensor([batch, h_q, dpe], dtype),
         KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
         K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
-        block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
-        cache_seqlens: T.Tensor([batch], "int32"),
+        block_table: T.Tensor([batch, max_seqlen_pad // block_size], T.int32),
+        cache_seqlens: T.Tensor([batch], T.int32),
         glse: T.Tensor([batch, h_q, num_split], dtype),
         Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
         Output: T.Tensor([batch, h_q, dv], dtype),
......
@@ -16,8 +16,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
......
@@ -27,8 +27,8 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
     sm_scale = float(softmax_scale * 1.44269504)  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
......
@@ -15,9 +15,9 @@ import argparse
 )
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    q_dtype = "float8_e4m3"
-    accum_dtype = "float"
+    dtype = T.float16
+    q_dtype = T.float8_e4m3fn
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
......
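The last hunk quantizes Q to `float8_e4m3fn` while K/V stay in float16. As a reminder of what that format can hold: the "fn" (finite) variant has no infinities, so the top exponent code is reclaimed for finite values, and the extremes follow directly from 4 exponent bits (bias 7) and 3 mantissa bits. A quick framework-independent check:

```python
# e4m3fn: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits, no inf.
# Largest finite value: exponent field 0b1111 with mantissa 0b110
# (mantissa 0b111 at that exponent encodes NaN) -> 1.75 * 2**8 = 448.
max_e4m3fn = (1 + 6 / 8) * 2.0 ** (15 - 7)

# Smallest positive normal: exponent field 0b0001 -> 2**(1 - 7).
min_normal_e4m3fn = 2.0 ** (1 - 7)

print(max_e4m3fn, min_normal_e4m3fn)  # 448.0 0.015625
```

The narrow range is why the attention kernel keeps the accumulator in `accum_dtype` (float32) rather than FP8.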