Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub

[Enhancement] Update examples and tests for improved type handling and functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
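The shorthand-to-explicit normalization this bullet describes can be sketched in plain Python; the helper name and lookup table below are illustrative assumptions, not tilelang's actual API:

```python
# Hypothetical helper mirroring the change by hand: the legacy accumulation
# dtype shorthand "float" is spelled out as the explicit "float32".
CANONICAL = {
    "float": "float32",   # legacy shorthand used in the old examples
    "half": "float16",
    "float16": "float16",
    "float32": "float32",
}

def canonical_accum_dtype(name: str) -> str:
    """Return the canonical spelling for a (possibly shorthand) dtype name."""
    try:
        return CANONICAL[name]
    except KeyError:
        raise ValueError(f"unknown dtype spelling: {name!r}") from None
```

Spelling the accumulator out as float32 removes any ambiguity about accumulation precision in mixed-precision GEMM examples.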

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Replaced string data type representations in the Hint class with T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
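A minimal sketch of the canonical-to-display mapping and the sharpened error message described above; the table contents, the set of known dtypes, and the function name are assumptions, not tilelang's real dtypes module:

```python
# Illustrative canonical-dtype mapping: ambiguous or shorthand string inputs
# are normalized to a single display spelling before lookup.
_DISPLAY = {
    "float8_e4m3": "float8_e4m3fn",  # resolve the ambiguous float8 spelling
    "float": "float32",
    "int": "int32",
}
_KNOWN = {"float16", "float32", "int8", "int32", "float8_e4m3fn", "float8_e5m2"}

def create_dtype(name: str) -> str:
    """Normalize a string dtype input, failing loudly on unknown formats."""
    canonical = _DISPLAY.get(name, name)
    if canonical not in _KNOWN:
        # Refined error message: name the offending input and the valid set.
        raise ValueError(f"invalid dtype {name!r}; expected one of {sorted(_KNOWN)}")
    return canonical
```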

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from `dtype..as_torch()` to `dtype.as_torch()` across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
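The corrected call chain can be illustrated with a small stand-in class; `DType` below mocks tilelang's dtype wrapper (torch is deliberately not imported, so the mapped values are placeholder strings rather than real torch dtypes):

```python
class DType:
    """Stand-in for tilelang's dtype wrapper; not the real implementation."""

    _TORCH_NAMES = {"float16": "torch.float16", "float32": "torch.float32"}

    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Placeholder: the real method returns an actual torch.dtype object.
        return self._TORCH_NAMES[self.name]

# Old (broken):  DType("float16")..as_torch()   -- a SyntaxError in Python
# Fixed:         DType("float16").as_torch()
torch_dtype = DType("float16").as_torch()
```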

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
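The mechanical rename can be expressed as a regular-expression sweep; this is a hedged reconstruction of the edit, not a script from the repository:

```python
import re

def rename_float8(src: str) -> str:
    # Rewrite bare float8_e4m3 to the explicit float8_e4m3fn, leaving
    # already-suffixed variants (fn, fnuz) untouched via a negative lookahead.
    return re.sub(r"float8_e4m3(?!fn)", "float8_e4m3fn", src)
```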

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
......@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm_schedule(
A: T.Tensor((M, K), dtype),
......
......@@ -53,8 +53,8 @@ def get_configs():
)
@tilelang.jit(out_idx=[-1])
def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type):
dtype = "float8_e4m3fnuz"
accum_dtype = "float"
dtype = T.float8_e4m3fnuz
accum_dtype = T.float32
@T.prim_func
def gemm_fp8_rs(
......
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
def calc_diff(x, y):
......@@ -12,7 +11,7 @@ def calc_diff(x, y):
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
@T.prim_func
def gemm_fp8(
A: T.Tensor((M, K), dtype),
......@@ -36,7 +35,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
......@@ -56,8 +55,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 1024, "float8_e4m3")
test_gemm_fp8(1024, 1024, 1024, "float8_e5m2")
test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2)
if __name__ == "__main__":
......
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype=T.float32):
# for fp8 gemm, do one promote after 4 wgmma inst, i.e. block_K = 128.
# if block_K < 128, promote after 128/block_K iters.
# if block_K > 128, promote after every iter.
......@@ -55,7 +54,7 @@ def calc_diff(x, y):
def test_gemm_fp8(M, N, K, dtype):
torch_dtype = map_torch_type(dtype)
torch_dtype = T.dtype(dtype).as_torch()
kernel = matmul(M, N, K, 128, 128, 64, dtype)
......@@ -74,8 +73,8 @@ def test_gemm_fp8(M, N, K, dtype):
def main():
test_gemm_fp8(1024, 1024, 8192, "float8_e4m3")
test_gemm_fp8(1024, 1024, 8192, "float8_e5m2")
test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn)
test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2)
if __name__ == "__main__":
......
......@@ -39,26 +39,26 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"float8_e4m3",
"float8_e5m2",
"int8",
T.float16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == "int32" or is_float8:
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
......@@ -66,7 +66,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......@@ -220,8 +220,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def main():
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
......
......@@ -73,8 +73,8 @@ block_M, block_N, block_K = 64, 256, 32
trans_A, trans_B = False, True
num_stages = 2
threads = 256
for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]:
for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]:
for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]:
for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]:
torch_fp8_dtype = map_torch_type(tvm_fp8_dtype)
torch_acc_dtype = map_torch_type(tvm_acc_dtype)
print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}")
......
......@@ -40,19 +40,19 @@ import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, K), "bfloat16"),
B: T.Tensor((N, K), "bfloat16"),
C: T.Tensor((M, N), "bfloat16"),
A: T.Tensor((M, K), T.bfloat16),
B: T.Tensor((N, K), T.bfloat16),
C: T.Tensor((M, N), T.bfloat16),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
# 1. Allocate memory buffers
A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory
A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) # A matrix shared memory
B_shared = T.alloc_shared((block_N, block_K), T.bfloat16) # B matrix shared memory
C_tmem = T.alloc_tmem([block_M, block_N], T.float) # TCGEN5MMA output to Tensor Memory
mbar = T.alloc_barrier(1) # mbarrier synchronization primitive
C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage
C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory
C_local = T.alloc_fragment((block_M, block_N), T.float) # Register storage
C_shared = T.alloc_shared((block_M, block_N), T.bfloat16) # Output shared memory
# 2. Main computation loop
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
......
......@@ -4,7 +4,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
......@@ -54,7 +54,7 @@ def matmul(
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
num_stages = 2
threads = 256
......
......@@ -17,7 +17,7 @@ torch.manual_seed(42)
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
"float": {
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 64,
......@@ -26,7 +26,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
T.float16: {
"block_M": 256,
"block_N": 128,
"block_K": 64,
......@@ -37,7 +37,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
},
},
"h20": {
"float": {
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
......@@ -46,7 +46,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
T.float16: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
......@@ -65,26 +65,26 @@ ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
def matmul_sp_fp16_custom_compress(
M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy, enable_rasterization, use_cutlass_layout
):
e_factor, e_dtype = (16, "int16")
e_factor, e_dtype = (16, T.int16)
@T.prim_func
def gemm_sp_fp16_custom_compress(
A_sparse: T.Tensor((M, K // 2), "float16"),
A_sparse: T.Tensor((M, K // 2), T.float16),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), "float16"),
B: T.Tensor((K, N), T.float16),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), "float16")
B_shared = T.alloc_shared((block_K, block_N), T.float16)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
if use_cutlass_layout:
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
}
)
T.clear(C_local)
......@@ -253,15 +253,15 @@ def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
if use_cutlass_layout:
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="8.0", block_k=block_K),
E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="8.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="8.0", block_k=block_K),
}
)
T.clear(A_sp_shared)
T.clear(E_shared)
# TODO: alloc_var seems buggy here
non_zero_cnt = T.alloc_local((1,), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem,), dtype="uint8")
non_zero_cnt = T.alloc_local((1,), dtype=T.uint8)
non_zero_elt_log_idx = T.alloc_local((elem,), dtype=T.uint8)
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
......@@ -300,7 +300,7 @@ def main():
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor")
parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16_custom_compress(
......@@ -314,7 +314,7 @@ def main():
assert not args.use_cutlass_layout, "torch sparse must be used with naive layout"
a_sparse, e = torch_compress(a)
else:
a_sparse, e = compress_kernel(args.m, args.k, 32, 32, "float16", use_cutlass_layout=args.use_cutlass_layout)(a)
a_sparse, e = compress_kernel(args.m, args.k, 32, 32, T.float16, use_cutlass_layout=args.use_cutlass_layout)(a)
c = kernel(a_sparse, e, b)
......
......@@ -16,7 +16,7 @@ arch = nvcc.get_target_compute_version()
DEFAULT_CONFIG = { # take best config from autotune script
"4090": {
"float": {
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 64,
......@@ -25,7 +25,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
T.float16: {
"block_M": 256,
"block_N": 128,
"block_K": 64,
......@@ -36,7 +36,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
},
},
"h20": {
"float": {
T.float: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
......@@ -45,7 +45,7 @@ DEFAULT_CONFIG = { # take best config from autotune script
"policy": T.GemmWarpPolicy.Square,
"enable_rasterization": True,
},
"float16": {
T.float16: {
"block_M": 128,
"block_N": 64,
"block_K": 128,
......@@ -66,15 +66,15 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), "float16"),
A_sparse: T.Tensor((M, K // 2), T.float16),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), "float16"),
B: T.Tensor((K, N), T.float16),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), "float16")
A_shared = T.alloc_shared((block_M, block_K // 2), T.float16)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), "float16")
B_shared = T.alloc_shared((block_K, block_N), T.float16)
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -83,8 +83,8 @@ def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages,
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype="float16", block_k=block_K, arch=arch),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", block_k=block_K, arch=arch),
E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, block_k=block_K, arch=arch),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, block_k=block_K, arch=arch),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......@@ -104,7 +104,7 @@ def main():
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090")
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype, **DEFAULT_CONFIG[args.cfg][args.accum_dtype])
......
......@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
splitK = K // split_k
@T.prim_func
......
......@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float", out_dtype="float32"):
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype=T.float16, accum_dtype=T.float32, out_dtype=T.float32):
splitK = K // split_k
@T.prim_func
......
......@@ -87,8 +87,8 @@ def tl_matmul_streamk(
C: T.Tensor,
C_local: T.LocalBuffer,
):
start_iter = T.alloc_fragment((1,), "int32", "local")
end_iter = T.alloc_fragment((1,), "int32", "local")
start_iter = T.alloc_fragment((1,), T.int32, "local")
end_iter = T.alloc_fragment((1,), T.int32, "local")
start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles)
last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles)
......@@ -179,9 +179,9 @@ def main():
BLOCK_SIZE_K,
False,
True,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
2,
64,
)
......
......@@ -17,8 +17,8 @@ def naive_gemv(
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
@T.prim_func
def main(
......@@ -49,8 +49,8 @@ def naive_splitk_gemv(
K: int,
BLOCK_N: int,
BLOCK_K: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
@T.prim_func
def main(
......@@ -85,8 +85,8 @@ def splitk_gemv(
BLOCK_N: int,
BLOCK_K: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
TILE_K = T.ceildiv(BLOCK_K, reduce_threads)
......@@ -124,8 +124,8 @@ def splitk_gemv_vectorized(
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
......@@ -165,8 +165,8 @@ def splitk_gemv_vectorized_tvm(
K: int,
BLOCK_N: int,
reduce_threads: int,
dtype: str = "float16",
accum_dtype: str = "float",
dtype: T.dtype = T.float16,
accum_dtype: T.dtype = T.float,
):
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
......@@ -233,7 +233,9 @@ def get_block_template_configs():
},
out_idx=[2],
)
def gemv_alloc_reducer(M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: str = "float16", accum_dtype: str = "float"):
def gemv_alloc_reducer(
M, N, block_M=128, block_N=128, num_stages=2, threads=256, dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float
):
@T.prim_func
def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)): # type: ignore
with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
......@@ -274,8 +276,8 @@ def get_autotuned_kernel(
BLOCK_N=None,
reduce_threads=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
MAX_TRANSACTION_SIZE_IN_BITS = 128
TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
BLOCK_K = reduce_threads * TILE_K
......
......@@ -6,29 +6,29 @@ import tilelang.language as T
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype = "float32"
accum_dtype = T.float32
@T.prim_func
def kernel(
A: T.Tensor([batch_sum, K], dtype), # type: ignore
B: T.Tensor([batch_count, K, N], dtype), # type: ignore
C: T.Tensor([batch_sum, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore
batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore
):
with T.Kernel(T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
cur_batch_idx = T.alloc_local([1], "int32")
cur_batch_size = T.alloc_local([1], "int32")
cur_batch_idx = T.alloc_local([1], T.int32)
cur_batch_size = T.alloc_local([1], T.int32)
m_start_padded = bx * block_M
......@@ -158,21 +158,21 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
b (torch.Tensor): Input tensor of shape (G, K, N).
"""
accum_dtype = "float32"
accum_dtype = T.float32
@T.prim_func
def kernel(
A: T.Tensor([batch_sum, M], dtype), # type: ignore
B: T.Tensor([batch_sum, N], dtype), # type: ignore
C: T.Tensor([batch_count, M, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore
batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared([block_K, block_M], dtype)
......
......@@ -37,7 +37,7 @@ def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
@tilelang.jit(out_idx=[2])
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
"""
args:
a (torch.Tensor): Input tensor of shape (M, K).
......@@ -45,7 +45,7 @@ def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2
"""
batch_sum = sum(batch_sizes_list)
batch_count = len(batch_sizes_list)
accum_dtype = "float32"
accum_dtype = T.float32
total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)
@T.prim_func
......@@ -53,16 +53,16 @@ def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2
A: T.Tensor([batch_sum, K], dtype), # type: ignore
B: T.Tensor([batch_count, K, N], dtype), # type: ignore
C: T.Tensor([batch_sum, N], dtype), # type: ignore
batch_sizes: T.Tensor([batch_count], "int32"), # type: ignore
batch_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], "int32"), # type: ignore
batch_sizes: T.Tensor([batch_count], T.int32), # type: ignore
batch_offsets: T.Tensor([batch_count], T.int32), # type: ignore
batch_padded_offsets: T.Tensor([batch_count], T.int32), # type: ignore
):
with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared([block_M, block_K], dtype)
B_shared = T.alloc_shared([block_K, block_N], dtype)
C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
cur_batch_idx = T.alloc_local([1], "int32")
cur_batch_size = T.alloc_local([1], "int32")
cur_batch_idx = T.alloc_local([1], T.int32)
cur_batch_size = T.alloc_local([1], T.int32)
m_start_padded = bx * block_M
......
......@@ -17,7 +17,7 @@ def is_pow_of_2(n):
def hadamard(b, n, dtype):
assert is_pow_of_2(n), "n must be a power of 2"
assert 2 <= n <= 32768, "n must be in [2, 32768]"
elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]
elem_size = {T.float32: 4, T.float16: 2, T.bfloat16: 2}[dtype]
logN = int(math.log2(n))
threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
......@@ -138,7 +138,7 @@ def main():
B, D = args.batch, args.dim
x = torch.randn((B, D), device="cuda")
kernel = hadamard(B, D, "float32")
kernel = hadamard(B, D, T.float32)
y = kernel(x)
y_ref = ref_program(x)
torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
......
......@@ -552,7 +552,7 @@
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"# import tilelang.language as T\n",
"\n",
"@T.prim_func\n",
"def foo(x_handle: T.handle):\n",
......@@ -723,7 +723,7 @@
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"# import tilelang.language as T\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
......@@ -786,4 +786,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
\ No newline at end of file
......@@ -552,7 +552,7 @@
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"# import tilelang.language as T\n",
"\n",
"@T.prim_func\n",
"def foo(x_handle: T.handle):\n",
......@@ -723,7 +723,7 @@
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"# import tilelang.language as T\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
......@@ -786,4 +786,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
\ No newline at end of file