Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
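The canonical-name mapping described above can be sketched as follows. This is a hypothetical illustration of the idea, not tilelang's actual code: the names `_CANONICAL_TO_DISPLAY` and `canonicalize_dtype` are invented for this sketch.

```python
# Illustrative sketch (assumed names, not tilelang internals): map every
# accepted dtype string, including legacy aliases, to one canonical
# display string, and fail loudly on anything unrecognized.
_CANONICAL_TO_DISPLAY = {
    "float32": "float32",
    "float": "float32",  # legacy alias normalized to the canonical name
    "float16": "float16",
    "half": "float16",
    "int32": "int32",
    "int8": "int8",
}

def canonicalize_dtype(name: str) -> str:
    """Return the canonical display string for a user-supplied dtype name."""
    try:
        return _CANONICAL_TO_DISPLAY[name]
    except KeyError:
        raise ValueError(f"invalid dtype string: {name!r}") from None
```

With a table like this, `"float"` and `"float32"` resolve to the same dtype, which is the kind of intuitive string handling the commit describes.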

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.
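The argparse pattern this commit moves to can be sketched as below: the CLI surface stays plain strings, and the script resolves them to framework dtypes afterwards. The `STR_TO_TORCH_NAME` mapping is illustrative only, not code from the repository.

```python
import argparse

# Sketch of string-based dtype arguments (assumed mapping, for illustration):
# argparse validates against plain strings, and the resolved dtype is looked
# up afterwards rather than passing dtype objects through the CLI.
STR_TO_TORCH_NAME = {
    "float16": "torch.float16",
    "int8": "torch.int8",
    "float": "torch.float32",
}

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, default="float16",
                    choices=["float16", "int8", "float"], help="Data type")
args = parser.parse_args(["--dtype", "int8"])
resolved = STR_TO_TORCH_NAME[args.dtype]
```

Keeping `choices` as strings also gives users a readable `--help` message and a clean error when an unsupported dtype is passed.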

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
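The bug this commit fixes is mechanical: `dtype..as_torch()` is a Python syntax error, while `dtype.as_torch()` is an ordinary method call. The `DType` class below is a minimal hypothetical stand-in, not tilelang's implementation; it returns dtype *names* instead of `torch.dtype` objects so the sketch stays self-contained.

```python
# Hypothetical sketch of the accessor the fix restores. The doubled dot in
# "dtype..as_torch()" does not parse; the corrected form is a plain
# attribute access followed by a call.
class DType:
    _TORCH_NAMES = {"float16": "torch.float16", "bfloat16": "torch.bfloat16"}

    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Real code would return a torch.dtype; strings keep this runnable
        # without torch installed.
        return self._TORCH_NAMES[self.name]

torch_dtype = DType("float16").as_torch()
```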

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
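PyTorch exposes its e4m3 float8 type as `torch.float8_e4m3fn` (the "fn" variant is finite-only, with no infinities), which is why the bare `float8_e4m3` spelling is replaced. A migration helper mirroring that substitution might look like this; `migrate_dtype_name` is a hypothetical name for illustration, not a repo function.

```python
# Illustrative rename table matching the commit's substitution: the bare
# "float8_e4m3" name becomes "float8_e4m3fn", aligned with PyTorch's
# torch.float8_e4m3fn. Unknown names pass through unchanged.
_FLOAT8_RENAMES = {"float8_e4m3": "float8_e4m3fn"}

def migrate_dtype_name(name: str) -> str:
    return _FLOAT8_RENAMES.get(name, name)
```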

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -80,16 +80,16 @@ def flashattn(batch_size, UQ, UKV, heads, dim, is_causal, block_M=64, block_N=64
     v_shape = [UKV, heads, dim]
     o_shape = [UQ, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32

     @T.prim_func
     def main(
         Q_unpad: T.Tensor(q_shape, dtype),
         K_unpad: T.Tensor(k_shape, dtype),
         V_unpad: T.Tensor(v_shape, dtype),
-        cu_seqlens_q: T.Tensor([batch_size + 1], "int32"),
-        cu_seqlens_k: T.Tensor([batch_size + 1], "int32"),
+        cu_seqlens_q: T.Tensor([batch_size + 1], T.int32),
+        cu_seqlens_k: T.Tensor([batch_size + 1], T.int32),
         max_seqlen_q: T.int32,
         Output_unpad: T.Tensor(o_shape, dtype),
     ):
@@ -53,8 +53,8 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
     shape_k = [batch, seqlen_kv, groups, dim]
     shape_v = [batch, seqlen_kv, groups, dim]
     shape_o = [batch, heads, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // groups
     part_shape = [batch, heads, num_split, dim]
@@ -209,8 +209,8 @@ def flashattn(
     shape_v = [total_seqlen_k, k_heads, dim]
     shape_o = [batch, heads, dim]
     shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // k_heads
     valid_block_H = min(block_H, kv_group_num)
@@ -221,8 +221,8 @@ def flashattn(
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),
-        s_aux: T.Tensor([heads], "float32"),
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
+        s_aux: T.Tensor([heads], T.float32),
         Output: T.Tensor([batch, heads, dim], dtype),
         S: T.Tensor(shape_s, dtype),
     ):
@@ -241,7 +241,7 @@ def flashattn(
         logsum = T.alloc_fragment([block_H], accum_dtype)
         S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
         # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
-        s_aux_shared = T.alloc_shared([block_H], "float32")
+        s_aux_shared = T.alloc_shared([block_H], T.float32)
         T.annotate_layout(
             {
@@ -321,8 +321,8 @@ def flashattn(
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),
-        s_aux: T.Tensor([heads], "float32"),
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
+        s_aux: T.Tensor([heads], T.float32),
         Output: T.Tensor(shape_o, dtype),
         S: T.Tensor(shape_s, dtype),
     ):
@@ -449,7 +449,7 @@ def test_equal_seqlen_decode_main(args):
     real_max_k_seqlen = args.k_seqlen
     head_size = args.head_size
     block_size = args.block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     # For decode, query is just 1 token per batch
     q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
@@ -568,7 +568,7 @@ def test_varlen_decode_main(args):
     real_max_k_seqlen = args.k_seqlen
     head_size = args.head_size
     block_size = args.block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
@@ -789,7 +789,7 @@ def speed_benchmark_decode_comparison(args):
     max_k_seqlen = args.k_seqlen
     head_size = args.head_size
     block_size = args.block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     print("\n=== Decode Speed Benchmark Comparison ===")
     print("Configuration:")
@@ -890,7 +890,7 @@ if __name__ == "__main__":
     parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
     parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
     parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
-    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
+    parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
     parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
     parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
     parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
@@ -898,7 +898,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     args.test_sink = True
     args.test_varlen = False
-    args.dtype = "float16"
+    args.dtype = T.float16
     args.num_split = 1
     if args.benchmark:
@@ -45,8 +45,8 @@ def flashattn(
     shape_v = [total_seqlen_k, k_heads, dim]
     shape_o = [batch, heads, dim]
     shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // k_heads
     assert page_block_size >= block_N and page_block_size % block_N == 0, (
         "page_block_size must be larger than block_N and a multiple of block_N"
@@ -60,9 +60,9 @@ def flashattn(
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),
-        s_aux: T.Tensor([heads], "float32"),
-        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], "int32"),
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
+        s_aux: T.Tensor([heads], T.float32),
+        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32),
         Output: T.Tensor([batch, heads, dim], dtype),
         S: T.Tensor(shape_s, dtype),
     ):
@@ -80,7 +80,7 @@ def flashattn(
         scores_sum = T.alloc_fragment([block_H], accum_dtype)
         logsum = T.alloc_fragment([block_H], accum_dtype)
         S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
-        s_aux_shared = T.alloc_shared([block_H], "float32")
+        s_aux_shared = T.alloc_shared([block_H], T.float32)
         bid = bx
         hid = by
@@ -146,9 +146,9 @@ def flashattn(
         Q: T.Tensor(shape_q, dtype),
         K: T.Tensor(shape_k, dtype),
         V: T.Tensor(shape_v, dtype),
-        cu_seqlens_k: T.Tensor([batch + 1], "int32"),
-        s_aux: T.Tensor([heads], "float32"),
-        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], "int32"),
+        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
+        s_aux: T.Tensor([heads], T.float32),
+        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32),
         Output: T.Tensor(shape_o, dtype),
         S: T.Tensor(shape_s, dtype),
     ):
@@ -211,7 +211,7 @@ def test_equal_seqlen_decode_main(args):
     head_size = args.head_size
     block_size = args.block_size
     page_block_size = args.page_block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     # For decode, query is just 1 token per batch
     q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
@@ -341,7 +341,7 @@ def test_varlen_decode_main(args):
     head_size = args.head_size
     block_size = args.block_size
     page_block_size = args.page_block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
@@ -549,7 +549,7 @@ def speed_benchmark_decode_comparison(args):
     head_size = args.head_size
     block_size = args.block_size
     page_block_size = args.page_block_size
-    dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
+    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
     print("\n=== Decode Speed Benchmark Comparison ===")
     print("Configuration:")
@@ -659,7 +659,7 @@ if __name__ == "__main__":
     parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
     parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
    parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
-    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"], help="Data type")
+    parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
     parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
     parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
     parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
@@ -668,7 +668,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     args.test_sink = True
     args.test_varlen = True
-    args.dtype = "float16"
+    args.dtype = T.float16
     args.num_split = 1
     if args.benchmark:
@@ -14,8 +14,8 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
     shape_q = [batch, seqlen_q, heads, dim]
     shape_kv = [batch, seqlen_kv, heads, dim]
     part_shape = [batch, seqlen_q, heads, num_split, dim]
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32

     @T.macro
     def MMA0(
@@ -33,7 +33,7 @@ def moe_forward_tilelang_shared(
     shared_W_up_shape = (dexpert, dhidden)
     shared_W_down_shape = (dhidden, dexpert)
-    accum_type = "float32"
+    accum_type = T.float32

     @T.prim_func
     def kernel_shared(
@@ -121,7 +121,7 @@ def moe_forward_tilelang_routed(
     # group_count = len(group_sizes_list)
     # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
     M = math.ceil(group_sum / block_token) + group_count
-    accum_dtype = "float32"
+    accum_dtype = T.float32
     # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm
     input_shape = (group_sum, dhidden)
@@ -139,10 +139,10 @@ def moe_forward_tilelang_routed(
         routed_expert_up: T.Tensor(routed_expert_up_shape, dtype),  # type: ignore
         routed_expert_down: T.Tensor(routed_expert_down_shape, dtype),  # type: ignore
         routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype),  # type: ignore
-        group_sizes: T.Tensor(group_sizes_shape, "int32"),  # type: ignore
-        group_offsets: T.Tensor(group_sizes_shape, "int32"),  # type: ignore
-        group_padded_offsets: T.Tensor(group_sizes_shape, "int32"),  # type: ignore
-        group_idx_for_bx: T.Tensor((M,), "int32"),  # type: ignore
+        group_sizes: T.Tensor(group_sizes_shape, T.int32),  # type: ignore
+        group_offsets: T.Tensor(group_sizes_shape, T.int32),  # type: ignore
+        group_padded_offsets: T.Tensor(group_sizes_shape, T.int32),  # type: ignore
+        group_idx_for_bx: T.Tensor((M,), T.int32),  # type: ignore
         up_logits: T.Tensor(intermediate_shape, dtype),  # type: ignore
         output: T.Tensor(input_shape, dtype),  # type: ignore
     ):
@@ -155,8 +155,8 @@ def moe_forward_tilelang_routed(
         gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
         up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
-        cur_group_idx = T.alloc_local([1], "int32")
-        cur_group_size = T.alloc_local([1], "int32")
+        cur_group_idx = T.alloc_local([1], T.int32)
+        cur_group_size = T.alloc_local([1], T.int32)
         T.use_swizzle(10, enable=True)
@@ -208,8 +208,8 @@ def moe_forward_tilelang_routed(
         routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
         output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
-        cur_group_idx = T.alloc_local([1], "int32")
-        cur_group_size = T.alloc_local([1], "int32")
+        cur_group_idx = T.alloc_local([1], T.int32)
+        cur_group_size = T.alloc_local([1], T.int32)
         T.use_swizzle(10, enable=True)
@@ -464,7 +464,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
     """
     input_tensor, weights, config = data
-    dtype_str = "float16"
+    dtype_str = T.float16
     shared_kernel = moe_forward_tilelang_shared(
         config["d_hidden"],
@@ -250,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
         dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
         dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
         dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
-        dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32")
-        dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32")
-        dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32")
+        dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32)
+        dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32)
+        dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32)
         K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
         Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
-        Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32")
+        Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32)
         W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
         G_last_local = T.alloc_local((1), dtype=gate_dtype)
@@ -592,11 +592,11 @@ def main():
         H=8,
         DK=DK,
         DV=128,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
-        gate_dtype="float32",
-        state_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
+        gate_dtype=T.float32,
+        state_dtype=T.float32,
         chunk_size=64,
         scale=DK**-0.5,
         use_g=True,
@@ -387,11 +387,11 @@ def main():
         H=32,
         DK=128,
         DV=128,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
-        gate_dtype="float32",
-        state_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
+        gate_dtype=T.float32,
+        state_dtype=T.float32,
         chunk_size=64,
         use_g=True,
         use_initial_state=False,
@@ -230,10 +230,10 @@ def main():
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
-        gate_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
+        gate_dtype=T.float32,
         use_g=True,
         block_DK=128,
         block_DV=128,
@@ -505,11 +505,11 @@ def main():
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
-        gate_dtype="float32",
-        state_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
+        gate_dtype=T.float32,
+        state_dtype=T.float32,
         chunk_size=64,
         scale=DK**-0.5,
         # scale=1,
@@ -57,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
     H,
     DK,
     chunk_size=64,
-    input_dtype="bfloat16",
-    output_dtype="bfloat16",
-    accum_dtype="float32",
+    input_dtype=T.bfloat16,
+    output_dtype=T.bfloat16,
+    accum_dtype=T.float32,
     use_g=True,
     # kernel config
     block_S=64,
@@ -183,9 +183,9 @@ def main():
         H=32,
         DK=128,
         chunk_size=64,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
         use_g=True,
         block_DK=64,
         threads=128,
@@ -32,8 +32,8 @@ def tilelang_chunk_local_cumsum_scalar(
     is_varlen=False,
     head_first=False,
     reverse=False,
-    input_dtype="float16",
-    output_dtype="float32",
+    input_dtype=T.float16,
+    output_dtype=T.float32,
     # kernel config
     block_S=64,
     threads=256,
@@ -154,8 +154,8 @@ def main():
         chunk_size=64,
         reverse=True,
         head_first=False,
-        input_dtype="float32",
-        output_dtype="float32",
+        input_dtype=T.float32,
+        output_dtype=T.float32,
         threads=256,
         use_fragment=False,
     )
@@ -205,10 +205,10 @@ def main():
         DK=128,
         DV=128,
         chunk_size=64,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        gate_dtype="float32",
-        accum_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        gate_dtype=T.float32,
+        accum_dtype=T.float32,
         block_DK=64,
         block_DV=32,
         threads=128,
@@ -518,11 +518,11 @@ def main():
         H=8,
         DK=DK,
         DV=DV,
-        input_dtype="bfloat16",
-        output_dtype="bfloat16",
-        accum_dtype="float32",
-        gate_dtype="float32",
-        state_dtype="float32",
+        input_dtype=T.bfloat16,
+        output_dtype=T.bfloat16,
+        accum_dtype=T.float32,
+        gate_dtype=T.float32,
+        state_dtype=T.float32,
         chunk_size=64,
         block_DK=32,
         block_DV=32,
 import torch
 import tilelang.testing
 from tilelang import language as T

 B = 1
 S = 1024  # small but for test only.
 H = 32
 DK = 128
 DV = 128
-input_dtype = "bfloat16"
-output_dtype = "bfloat16"
-accum_dtype = "float32"
-gate_dtype = "float32"
-state_dtype = "float32"
+input_dtype = T.bfloat16
+output_dtype = T.bfloat16
+accum_dtype = T.float32
+gate_dtype = T.float32
+state_dtype = T.float32
 chunk_size = 64
 use_g = True
 use_initial_state = True
@@ -53,7 +53,7 @@ import tilelang
 from tilelang import Profiler
 import tilelang.language as T

-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
@@ -176,7 +176,7 @@ import tilelang.language as T
 # that helps align data for MMA (Matrix Multiply-Accumulate) operations.
 from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout

-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
@@ -265,18 +265,18 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
-    if out_dtype == "int32":
+    if out_dtype == T.int32:
         micro_size_k = 32
     # This is a debug config
@@ -3,7 +3,7 @@ import tilelang.language as T

 @tilelang.jit(out_idx=[-1])
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def gemm(
         A: T.Tensor((M, K), dtype),
@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20):
         M=M,
         N=N,
         K=K,
-        in_dtype="float16",
-        out_dtype="float16",
-        accum_dtype="float",
+        in_dtype=T.float16,
+        out_dtype=T.float16,
+        accum_dtype=T.float32,
     ).with_arch(arch)
     func = carve_template.equivalent_function()
@@ -116,8 +116,8 @@ def get_best_config(M, N, K, with_roller=False):
         thread_num=None,
         enable_rasteration=None,
     ):
-        dtype = "bfloat16"
-        accum_dtype = "float"
+        dtype = T.bfloat16
+        accum_dtype = T.float32

         @T.prim_func
         def main(
@@ -178,7 +178,7 @@ def get_heuristic_config() -> dict:

 @tl.jit(out_idx=[-1])
-def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def gemm_autotune(
         A: T.Tensor((M, K), dtype),
@@ -35,18 +35,18 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
-    if out_dtype == "int32":
+    if out_dtype == T.int32:
         micro_size_k = 32
     # This is a debug config
@@ -54,7 +54,7 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 64
     warp_col_tiles = 64
-    # chunk = 32 if in_dtype == "float16" else 64
+    # chunk = 32 if in_dtype == T.float16 else 64
     chunk = 32
     shared_scope = "shared.dyn"
@@ -163,7 +163,7 @@ def ref_program(A, B):
 def main(M=4096, N=4096, K=4096):
-    in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
+    in_dtype, out_dtype, accum_dtype = T.float16, T.float16, T.float32
     kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
     src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
@@ -5,7 +5,7 @@ import argparse

 @tilelang.jit(out_idx=[-1])
-def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"):
+def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
@@ -34,7 +34,7 @@ def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stage
 @tilelang.jit(out_idx=[-1])
 def matmul_persistent(
-    M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True
+    M, N, K, block_M, block_N, block_K, threads, num_stages, dtype=T.float16, accum_dtype=T.float32, use_persistent_primitive=True
 ):
     sm_num = driver.get_num_sms()
     m_blocks = T.ceildiv(M, block_M)