Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
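The numerical-stability point can be illustrated outside tilelang. The sketch below is plain Python (not tilelang code): it uses `struct` to round a running sum to IEEE half or single precision after every addition, showing why a float16 accumulator stalls while a float32 accumulator stays accurate.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def accumulate(n: int, value: float, fmt: str) -> float:
    """Sum `value` n times, rounding the running sum to `fmt` after each add."""
    s = 0.0
    v = round_to(fmt, value)
    for _ in range(n):
        s = round_to(fmt, s + v)
    return s

half_sum = accumulate(10_000, 0.01, "e")    # float16 accumulator
single_sum = accumulate(10_000, 0.01, "f")  # float32 accumulator
# Once the float16 running total's spacing exceeds twice the addend,
# further adds round to nothing; the float32 sum stays near 100.0.
print(half_sum, single_sum)
```

This is the same effect that motivates accumulating a float16 GEMM in float32 registers.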

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
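The shape of such a canonical-name mapping can be sketched as follows. The alias table and helper name below are assumptions modeled on dtype strings that appear elsewhere in this diff, not tilelang's actual table.

```python
# Hypothetical canonicalization table; entries mirror aliases seen in this PR
# (bare "float" used for float32, e4m3 renamed to the fn variant).
CANONICAL_DISPLAY = {
    "float": "float32",
    "half": "float16",
    "float8_e4m3": "float8_e4m3fn",
}

def canonical_dtype(name: str) -> str:
    """Return the canonical display string for a dtype name; pass unknown names through."""
    return CANONICAL_DISPLAY.get(name, name)

print(canonical_dtype("float"))   # canonicalized
print(canonical_dtype("int32"))   # unchanged
```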

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.
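The argparse pattern described here can be sketched standalone: the dtype is parsed as a plain string at the CLI boundary and converted once afterwards, rather than threading framework dtype objects through argparse. Flag names follow the benchmark scripts' convention.

```python
import argparse

# Parse the dtype as a string; convert to a framework dtype only after parsing.
parser = argparse.ArgumentParser()
parser.add_argument("--dtype", type=str, default="float16",
                    choices=["float16", "int8", "float"],
                    help="Input datatype")
args = parser.parse_args(["--dtype", "int8"])
print(args.dtype)  # "int8"
```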

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected a typo in the data type conversion call, from dtype..as_torch() to dtype.as_torch(), across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.
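The variant check at the heart of this change can be sketched as a membership test over the float8 names that appear in this diff; this is a simplification for illustration, not the actual codegen implementation.

```python
# Float8 variants that appear in this PR's examples (see the is_float8 check
# in the tl_matmul test further down); simplified to a string-set lookup.
FLOAT8_VARIANTS = {
    "float8_e4m3",      # legacy alias, renamed to float8_e4m3fn in this PR
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
}

def is_float8(dtype) -> bool:
    """True if `dtype` (dtype object or string) names a float8 variant."""
    return str(dtype) in FLOAT8_VARIANTS

print(is_float8("float8_e4m3fn"), is_float8("float16"))
```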

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -3,7 +3,7 @@ import tilelang.language as T
 @tilelang.jit(out_idx=[-1])
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def gemm_schedule(
         A: T.Tensor((M, K), dtype),
...
@@ -53,8 +53,8 @@ def get_configs():
 )
 @tilelang.jit(out_idx=[-1])
 def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
-    dtype = "float8_e4m3fnuz"
-    accum_dtype = "float"
+    dtype = T.float8_e4m3fnuz
+    accum_dtype = T.float32
     @T.prim_func
     def gemm_fp8_rs(
...
 import torch
 import tilelang
 import tilelang.language as T
-from tilelang.utils.tensor import map_torch_type
 def calc_diff(x, y):
@@ -12,7 +11,7 @@ def calc_diff(x, y):
 @tilelang.jit(out_idx=[-1])
-def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
     @T.prim_func
     def gemm_fp8(
         A: T.Tensor((M, K), dtype),
@@ -36,7 +35,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
 def test_gemm_fp8(M, N, K, dtype):
-    torch_dtype = map_torch_type(dtype)
+    torch_dtype = T.dtype(dtype).as_torch()
     kernel = matmul(M, N, K, 128, 128, 64, dtype)
@@ -56,8 +55,8 @@ def test_gemm_fp8(M, N, K, dtype):
 def main():
-    test_gemm_fp8(1024, 1024, 1024, "float8_e4m3")
-    test_gemm_fp8(1024, 1024, 1024, "float8_e5m2")
+    test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
+    test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
 if __name__ == "__main__":
...
 import torch
 import tilelang
 import tilelang.language as T
-from tilelang.utils.tensor import map_torch_type
 @tilelang.jit(out_idx=[-1])
-def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
     # for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
     # if block_K < 128, promote after 128/block_K iters.
     # if block_K > 128, promote after every iter.
@@ -55,7 +54,7 @@ def calc_diff(x, y):
 def test_gemm_fp8(M, N, K, dtype):
-    torch_dtype = map_torch_type(dtype)
+    torch_dtype = T.dtype(dtype).as_torch()
     kernel = matmul(M, N, K, 128, 128, 64, dtype)
@@ -74,8 +73,8 @@ def test_gemm_fp8(M, N, K, dtype):
 def main():
-    test_gemm_fp8(1024, 1024, 8192, "float8_e4m3")
-    test_gemm_fp8(1024, 1024, 8192, "float8_e5m2")
+    test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
+    test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
 if __name__ == "__main__":
...
@@ -39,26 +39,26 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "float8_e4m3",
-        "float8_e5m2",
-        "int8",
+        T.float16,
+        T.float8_e4m3fn,
+        T.float8_e5m2,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
     is_float8 = in_dtype in [
-        "float8_e4m3",
-        "float8_e5m2",
-        "float8_e4m3fn",
-        "float8_e5m2fnuz",
+        T.float8_e4m3fn,
+        T.float8_e5m2,
+        T.float8_e4m3fn,
+        T.float8_e5m2fnuz,
     ]
-    if out_dtype == "int32" or is_float8:
+    if out_dtype == T.int32 or is_float8:
         micro_size_k = 32
     # This is a debug config
@@ -66,7 +66,7 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32 if in_dtype == "float16" else 64
+    chunk = 32 if in_dtype == T.float16 else 64
     shared_scope = "shared.dyn"
     # Pipeline Stage
@@ -220,8 +220,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 def main():
-    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
-    assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
+    assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
+    assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
 if __name__ == "__main__":
...
@@ -73,8 +73,8 @@ block_M, block_N, block_K = 64, 256, 32
 trans_A, trans_B = False, True
 num_stages = 2
 threads = 256
-for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
-    for tvm_acc_dtype in ["float16", "float32"]:  # , torch.float16]:
+for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]:
+    for tvm_acc_dtype in [T.float16, T.float32]:  # , torch.float16]:
         torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
         torch_acc_dtype = map_torch_type(tvm_acc_dtype)
         print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
...
@@ -40,19 +40,19 @@ import tilelang.language as T
 @T.prim_func
 def main(
-    A: T.Tensor((M, K), "bfloat16"),
-    B: T.Tensor((N, K), "bfloat16"),
-    C: T.Tensor((M, N), "bfloat16"),
+    A: T.Tensor((M, K), T.bfloat16),
+    B: T.Tensor((N, K), T.bfloat16),
+    C: T.Tensor((M, N), T.bfloat16),
 ):
     with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
         # 1. Allocate memory buffers
-        A_shared = T.alloc_shared((block_M, block_K), "bfloat16")  # A matrix shared memory
-        B_shared = T.alloc_shared((block_N, block_K), "bfloat16")  # B matrix shared memory
-        C_tmem = T.alloc_tmem([block_M, block_N], "float")  # TCGEN5MMA output to Tensor Memory
+        A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)  # A matrix shared memory
+        B_shared = T.alloc_shared((block_N, block_K), T.bfloat16)  # B matrix shared memory
+        C_tmem = T.alloc_tmem([block_M, block_N], T.float)  # TCGEN5MMA output to Tensor Memory
         mbar = T.alloc_barrier(1)  # mbarrier synchronization primitive
-        C_local = T.alloc_fragment((block_M, block_N), "float")  # Register storage
-        C_shared = T.alloc_shared((block_M, block_N), "bfloat16")  # Output shared memory
+        C_local = T.alloc_fragment((block_M, block_N), T.float)  # Register storage
+        C_shared = T.alloc_shared((block_M, block_N), T.bfloat16)  # Output shared memory
         # 2. Main computation loop
         for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
...
@@ -4,7 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
...
@@ -54,7 +54,7 @@ def matmul(
 M, N, K = 4096, 4096, 8192
 block_M, block_N, block_K = 128, 256, 128
 trans_A, trans_B = False, True
-in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
+in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
 num_stages = 2
 threads = 256
...
@@ -17,7 +17,7 @@ torch.manual_seed(42)
 DEFAULT_CONFIG = {  # take best config from autotune script
     "4090": {
-        "float": {
+        T.float: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 64,
@@ -26,7 +26,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
             "policy": T.GemmWarpPolicy.Square,
             "enable_rasterization": True,
         },
-        "float16": {
+        T.float16: {
             "block_M": 256,
             "block_N": 128,
             "block_K": 64,
@@ -37,7 +37,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
         },
     },
     "h20": {
-        "float": {
+        T.float: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 128,
@@ -46,7 +46,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
             "policy": T.GemmWarpPolicy.Square,
             "enable_rasterization": True,
         },
-        "float16": {
+        T.float16: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 128,
@@ -65,26 +65,26 @@ ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
 def matmul_sp_fp16_custom_compress(
     M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
 ):
-    e_factor, e_dtype = (16, "int16")
+    e_factor, e_dtype = (16, T.int16)
     @T.prim_func
     def gemm_sp_fp16_custom_compress(
-        A_sparse: T.Tensor((M, K // 2), "float16"),
+        A_sparse: T.Tensor((M, K // 2), T.float16),
         E: T.Tensor((M, K // e_factor), e_dtype),
-        B: T.Tensor((K, N), "float16"),
+        B: T.Tensor((K, N), T.float16),
         C: T.Tensor((M, N), accum_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
+            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
             E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
-            B_shared = T.alloc_shared((block_K, block_N), "float16")
+            B_shared = T.alloc_shared((block_K, block_N), T.float16)
             C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             if use_cutlass_layout:
                 T.annotate_layout(
                     {
-                        E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
-                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
+                        E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
+                        E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                     }
                 )
             T.clear(C_local)
@@ -253,15 +253,15 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
         if use_cutlass_layout:
             T.annotate_layout(
                 {
-                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
-                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
                 }
            )
        T.clear(A_sp_shared)
        T.clear(E_shared)
        # TODO: alloc_var seems buggy here
-        non_zero_cnt = T.alloc_local((1,), dtype="uint8")
-        non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8")
+        non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
+        non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
        T.copy(A[bx * block_M, by * block_K], A_shared)
        for tm in T.Parallel(block_M):
            for g_i in range(0, block_K // group):
@@ -300,7 +300,7 @@ def main():
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
     parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
     parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
-    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
     parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
     args = parser.parse_args()
     kernel = matmul_sp_fp16_custom_compress(
@@ -314,7 +314,7 @@ def main():
         assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
         a_sparse, e = torch_compress(a)
     else:
-        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a)
+        a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)
     c = kernel(a_sparse, e, b)
...
@@ -16,7 +16,7 @@ arch = nvcc.get_target_compute_version()
 DEFAULT_CONFIG = {  # take best config from autotune script
     "4090": {
-        "float": {
+        T.float: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 64,
@@ -25,7 +25,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
             "policy": T.GemmWarpPolicy.Square,
             "enable_rasterization": True,
         },
-        "float16": {
+        T.float16: {
             "block_M": 256,
             "block_N": 128,
             "block_K": 64,
@@ -36,7 +36,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
         },
     },
     "h20": {
-        "float": {
+        T.float: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 128,
@@ -45,7 +45,7 @@ DEFAULT_CONFIG = {  # take best config from autotune script
             "policy": T.GemmWarpPolicy.Square,
             "enable_rasterization": True,
         },
-        "float16": {
+        T.float16: {
             "block_M": 128,
             "block_N": 64,
             "block_K": 128,
@@ -66,15 +66,15 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
     @T.prim_func
     def gemm_sp_fp16(
-        A_sparse: T.Tensor((M, K // 2), "float16"),
+        A_sparse: T.Tensor((M, K // 2), T.float16),
         E: T.Tensor((M, K // e_factor), e_dtype),
-        B: T.Tensor((K, N), "float16"),
+        B: T.Tensor((K, N), T.float16),
         C: T.Tensor((M, N), accum_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
+            A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
             E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
-            B_shared = T.alloc_shared((block_K, block_N), "float16")
+            B_shared = T.alloc_shared((block_K, block_N), T.float16)
             C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
@@ -83,8 +83,8 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
             T.use_swizzle(panel_size=10, enable=enable_rasterization)
             T.annotate_layout(
                 {
-                    E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
-                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch),
                 }
             )
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -104,7 +104,7 @@ def main():
     parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
     parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
     parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
-    parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
+    parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
     parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
     args = parser.parse_args()
     kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
...
@@ -3,7 +3,7 @@ import tilelang.language as T
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
+def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
     splitK = K // split_k
     @T.prim_func
...
@@ -3,7 +3,7 @@ import tilelang.language as T
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
+def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
     splitK = K // split_k
     @T.prim_func
...
@@ -87,8 +87,8 @@ def tl_matmul_streamk(
         C: T.Tensor,
         C_local: T.LocalBuffer,
     ):
-        start_iter = T.alloc_fragment((1,), "int32", "local")
-        end_iter = T.alloc_fragment((1,), "int32", "local")
+        start_iter = T.alloc_fragment((1,), T.int32, "local")
+        end_iter = T.alloc_fragment((1,), T.int32, "local")
         start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
         last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)
@@ -179,9 +179,9 @@ def main():
         BLOCK_SIZE_K,
         False,
         True,
-        "float16",
-        "float16",
-        "float32",
+        T.float16,
+        T.float16,
+        T.float32,
         2,
         64,
     )
...
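The Stream-K hunk above assigns each program id a contiguous range of output tiles, with the first `streamk_partial_tiles` pids taking one extra tile. A minimal pure-Python sketch of that index arithmetic (the helper name is illustrative, not from the example script):

```python
# Pure-Python sketch of the Stream-K tile partitioning from the hunk above:
# each pid gets `full` tiles, and the first `partial` pids get one extra,
# so consecutive pids cover a contiguous, non-overlapping range of tiles.
def streamk_iter_range(pid: int, full: int, partial: int):
    start_iter = pid * full + min(pid, partial)
    last_iter = (pid + 1) * full + min(pid + 1, partial)
    return start_iter, last_iter
```

For example, 10 tiles over 4 pids (`full = 2`, `partial = 2`) yields the ranges (0, 3), (3, 6), (6, 8), (8, 10), which exactly partition the tile space.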
@@ -17,8 +17,8 @@ def naive_gemv(
     K: int,
     BLOCK_N: int,
     BLOCK_K: int,
-    dtype: str = "float16",
-    accum_dtype: str = "float",
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float,
 ):
     @T.prim_func
     def main(
@@ -49,8 +49,8 @@ def naive_splitk_gemv(
     K: int,
     BLOCK_N: int,
     BLOCK_K: int,
-    dtype: str = "float16",
-    accum_dtype: str = "float",
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float,
 ):
     @T.prim_func
     def main(
@@ -85,8 +85,8 @@ def splitk_gemv(
     BLOCK_N: int,
     BLOCK_K: int,
     reduce_threads: int,
-    dtype: str = "float16",
-    accum_dtype: str = "float",
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float,
 ):
     TILE_K = T.ceildiv(BLOCK_K, reduce_threads)
@@ -124,8 +124,8 @@ def splitk_gemv_vectorized(
     K: int,
     BLOCK_N: int,
     reduce_threads: int,
-    dtype: str = "float16",
-    accum_dtype: str = "float",
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float,
 ):
     MAX_TRANSACTION_SIZE_IN_BITS = 128
     TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
@@ -165,8 +165,8 @@ def splitk_gemv_vectorized_tvm(
     K: int,
     BLOCK_N: int,
     reduce_threads: int,
-    dtype: str = "float16",
-    accum_dtype: str = "float",
+    dtype: T.dtype = T.float16,
+    accum_dtype: T.dtype = T.float,
 ):
     MAX_TRANSACTION_SIZE_IN_BITS = 128
     TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
@@ -233,7 +233,9 @@ def get_block_template_configs():
     },
     out_idx=[2],
 )
-def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"):
+def gemv_alloc_reducer(
+    M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float
+):
     @T.prim_func
     def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)):  # type: ignore
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
@@ -274,8 +276,8 @@ def get_autotuned_kernel(
     BLOCK_N=None,
     reduce_threads=None,
 ):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     MAX_TRANSACTION_SIZE_IN_BITS = 128
     TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
     BLOCK_K = reduce_threads * TILE_K
...
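The vectorized GEMV variants above derive their per-thread tile width from a 128-bit memory transaction: `TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits`. A small sketch of that computation, with a plain bit-width table standing in for tilelang's `DataType` helper:

```python
# Sketch of the vectorized-load width computed in splitk_gemv_vectorized
# above. DTYPE_BITS is a stand-in for tilelang's DataType(dtype).bits;
# the 128-bit transaction size matches the constant in the diff.
MAX_TRANSACTION_SIZE_IN_BITS = 128
DTYPE_BITS = {"float16": 16, "bfloat16": 16, "float32": 32}

def tile_k(dtype: str) -> int:
    # number of elements moved per 128-bit vectorized memory transaction
    return MAX_TRANSACTION_SIZE_IN_BITS // DTYPE_BITS[dtype]
```

So a float16 kernel loads 8 elements per transaction while a float32 kernel loads 4, which is why `BLOCK_K = reduce_threads * TILE_K` depends on the dtype.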
@@ -6,29 +6,29 @@ import tilelang.language as T
 @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
-def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
+def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
     """
     args:
         a (torch.Tensor): Input tensor of shape (M, K).
         b (torch.Tensor): Input tensor of shape (G, K, N).
     """
-    accum_dtype = "float32"
+    accum_dtype = T.float32
     @T.prim_func
     def kernel(
         A: T.Tensor([batch_sum, K], dtype),  # type: ignore
         B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
         C: T.Tensor([batch_sum, N], dtype),  # type: ignore
-        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
-        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
-        batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
+        batch_sizes: T.Tensor([batch_count], T.int32),  # type: ignore
+        batch_offsets: T.Tensor([batch_count], T.int32),  # type: ignore
+        batch_padded_offsets: T.Tensor([batch_count], T.int32),  # type: ignore
     ):
         with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
             A_shared = T.alloc_shared([block_M, block_K], dtype)
             B_shared = T.alloc_shared([block_K, block_N], dtype)
             C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
-            cur_batch_idx = T.alloc_local([1], "int32")
-            cur_batch_size = T.alloc_local([1], "int32")
+            cur_batch_idx = T.alloc_local([1], T.int32)
+            cur_batch_size = T.alloc_local([1], T.int32)
             m_start_padded = bx * block_M
@@ -158,21 +158,21 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
 @tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
-def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
+def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
     """
     args:
         a (torch.Tensor): Input tensor of shape (M, K).
         b (torch.Tensor): Input tensor of shape (G, K, N).
     """
-    accum_dtype = "float32"
+    accum_dtype = T.float32
     @T.prim_func
     def kernel(
         A: T.Tensor([batch_sum, M], dtype),  # type: ignore
         B: T.Tensor([batch_sum, N], dtype),  # type: ignore
         C: T.Tensor([batch_count, M, N], dtype),  # type: ignore
-        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
-        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
+        batch_sizes: T.Tensor([batch_count], T.int32),  # type: ignore
+        batch_offsets: T.Tensor([batch_count], T.int32),  # type: ignore
     ):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
             A_shared = T.alloc_shared([block_K, block_M], dtype)
...
@@ -37,7 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
 @tilelang.jit(out_idx=[2])
-def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
+def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
     """
     args:
         a (torch.Tensor): Input tensor of shape (M, K).
@@ -45,7 +45,7 @@ def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2
     """
     batch_sum = sum(batch_sizes_list)
     batch_count = len(batch_sizes_list)
-    accum_dtype = "float32"
+    accum_dtype = T.float32
     total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)
     @T.prim_func
@@ -53,16 +53,16 @@ def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2
         A: T.Tensor([batch_sum, K], dtype),  # type: ignore
         B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
         C: T.Tensor([batch_sum, N], dtype),  # type: ignore
-        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
-        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
-        batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
+        batch_sizes: T.Tensor([batch_count], T.int32),  # type: ignore
+        batch_offsets: T.Tensor([batch_count], T.int32),  # type: ignore
+        batch_padded_offsets: T.Tensor([batch_count], T.int32),  # type: ignore
     ):
         with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
             A_shared = T.alloc_shared([block_M, block_K], dtype)
             B_shared = T.alloc_shared([block_K, block_N], dtype)
             C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
-            cur_batch_idx = T.alloc_local([1], "int32")
-            cur_batch_size = T.alloc_local([1], "int32")
+            cur_batch_idx = T.alloc_local([1], T.int32)
+            cur_batch_size = T.alloc_local([1], T.int32)
             m_start_padded = bx * block_M
...
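The grouped GEMM kernels above take `batch_sizes`, `batch_offsets`, and `batch_padded_offsets` tensors as inputs. A hedged sketch of how such offsets could be built on the host (the helper name is illustrative; the actual example scripts use their own `construct_inputs`): padded offsets round each group up to a multiple of `block_M` so every group starts on a tile boundary.

```python
# Illustrative host-side construction of the grouped-GEMM offset arrays:
# batch_offsets[i] is the row where group i starts, and
# batch_padded_offsets[i] is the same but with every earlier group
# rounded up to a multiple of block_M (one tile boundary per group).
def build_batch_offsets(batch_sizes_list, block_M):
    offsets, padded_offsets = [], []
    off = padded_off = 0
    for size in batch_sizes_list:
        offsets.append(off)
        padded_offsets.append(padded_off)
        off += size
        padded_off += (size + block_M - 1) // block_M * block_M
    return offsets, padded_offsets
```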
@@ -17,7 +17,7 @@ def is_pow_of_2(n):
 def hadamard(b, n, dtype):
     assert is_pow_of_2(n), "n must be a power of 2"
     assert 2 <= n <= 32768, "n must be in [2, 32768]"
-    elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]
+    elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype]
     logN = int(math.log2(n))
     threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
@@ -138,7 +138,7 @@ def main():
     B, D = args.batch, args.dim
     x = torch.randn((B, D), device="cuda")
-    kernel = hadamard(B, D, "float32")
+    kernel = hadamard(B, D, T.float32)
     y = kernel(x)
     y_ref = ref_program(x)
     torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
...
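The Hadamard hunk above gates `n` to a power of two in [2, 32768] and picks its thread count from a table indexed by log2(n). A minimal pure-Python sketch of that launch-parameter logic (thread table copied from the diff; the wrapper function name is illustrative):

```python
import math

# Sketch of the launch-parameter selection in the Hadamard example above:
# n must be a power of two in [2, 32768], and the thread count is looked
# up by logN = log2(n) in a fixed table.
def is_pow_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

THREADS_BY_LOGN = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256]

def hadamard_threads(n: int) -> int:
    assert is_pow_of_2(n), "n must be a power of 2"
    assert 2 <= n <= 32768, "n must be in [2, 32768]"
    return THREADS_BY_LOGN[int(math.log2(n))]
```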
@@ -552,7 +552,7 @@
 {
  "data": {
   "text/plain": [
-   "# from tvm.script import tir as T\n",
+   "# import tilelang.language as T\n",
    "\n",
    "@T.prim_func\n",
    "def foo(x_handle: T.handle):\n",
@@ -723,7 +723,7 @@
 {
  "data": {
   "text/plain": [
-   "# from tvm.script import tir as T\n",
+   "# import tilelang.language as T\n",
    "\n",
    "@T.prim_func\n",
    "def foo():\n",
@@ -786,4 +786,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
\ No newline at end of file
+}
@@ -552,7 +552,7 @@
 {
  "data": {
   "text/plain": [
-   "# from tvm.script import tir as T\n",
+   "# import tilelang.language as T\n",
    "\n",
    "@T.prim_func\n",
    "def foo(x_handle: T.handle):\n",
@@ -723,7 +723,7 @@
 {
  "data": {
   "text/plain": [
-   "# from tvm.script import tir as T\n",
+   "# import tilelang.language as T\n",
    "\n",
    "@T.prim_func\n",
    "def foo():\n",
@@ -786,4 +786,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
\ No newline at end of file
+}