Unverified commit c750fb8a, authored by Lei Wang and committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
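
The motivation for the wider accumulator can be sketched outside TileLang. The `to_fp16` helper below is purely illustrative (it round-trips a Python float through IEEE-754 half precision via `struct`) and is not part of the codebase; it shows why accumulating fp16 products in a 16-bit running sum loses accuracy while a wider accumulator does not:

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE-754 half precision (struct "e" format).
    return struct.unpack("e", struct.pack("e", x))[0]

def dot(a, b, accum_fp16: bool) -> float:
    acc = 0.0
    for x, y in zip(a, b):
        prod = to_fp16(x * y)      # products are fp16 either way
        acc = acc + prod
        if accum_fp16:
            acc = to_fp16(acc)     # round the running sum to fp16 too
    return acc

a = [0.01] * 4096
b = [1.0] * 4096
# With an fp16 accumulator the running sum eventually stops growing:
# once acc is large enough, each 0.01 addend rounds away entirely.
err_fp16 = abs(dot(a, b, accum_fp16=True) - 40.96)
err_fp32 = abs(dot(a, b, accum_fp16=False) - 40.96)
```

With the wide accumulator the error stays near rounding noise, while the fp16 accumulator stalls well short of the true sum — the behavior the switch to `T.float32` guards against.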

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.
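
A minimal sketch of what such a dtypes module can look like (hypothetical names and layout; the real `language.v2.dtypes` may differ): making each dtype a `str` subclass is one way to keep `T.float16` and the literal `"float16"` interchangeable, which is why the mixed usage seen across this history still works:

```python
class DType(str):
    """A dtype object that compares equal to its canonical string name."""
    __slots__ = ()

class _Dtypes:
    """Attribute-style access, e.g. dtypes.float16 == "float16"."""
    def __init__(self, names):
        for name in names:
            setattr(self, name, DType(name))

dtypes = _Dtypes([
    "int4", "int8", "int16", "int32",
    "uint8", "uint32", "uint64",
    "float16", "float32", "bfloat16", "bool",
])
```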

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
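
One plausible shape for that mapping (a sketch under assumed names, not the actual implementation): aliases like the bare `"float"` seen in earlier commits normalize to a canonical display string before a dtype is constructed:

```python
# Hypothetical canonical-name table; the mapping added by the PR may differ.
_CANONICAL_DISPLAY = {
    "float": "float32",   # bare "float" means 32-bit here
    "int": "int32",
    "half": "float16",
}

def make_dtype(name: str) -> str:
    """Normalize a user-supplied dtype string to its display form."""
    return _CANONICAL_DISPLAY.get(name, name)
```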

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected a typo in the data type conversion call, changing the double-dotted `dtype..as_torch()` to `dtype.as_torch()` across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
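
The corrected call site resolves a TileLang dtype to its PyTorch counterpart. A rough sketch of that lookup, with torch dtypes stood in by their names so the snippet needs no torch install (the real `as_torch()` returns actual `torch.dtype` objects):

```python
# Hypothetical stand-in table; real code maps to torch.dtype objects.
_TO_TORCH = {
    "float16": "torch.float16",
    "bfloat16": "torch.bfloat16",
    "float32": "torch.float32",
    "int8": "torch.int8",
    "int32": "torch.int32",
}

class dtype:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Single place to translate a TileLang dtype name to torch's.
        return _TO_TORCH[self.name]
```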

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
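
The rename tracks PyTorch's spelling: torch calls the finite-only e4m3 format `float8_e4m3fn`. A small normalizer in that spirit (illustrative only, not taken from the PR):

```python
def normalize_fp8(name: str) -> str:
    """Map legacy fp8 spellings to torch-style names (illustrative only)."""
    aliases = {
        "float8_e4m3": "float8_e4m3fn",  # torch's finite-only e4m3 variant
    }
    return aliases.get(name, name)
```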

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -25,7 +25,7 @@ def _check(original, transformed):
M = 512
N = 512
K = 512
dtype = "float16"
dtype = T.float16
block_M = 64
block_N = 64
block_K = 32
@@ -40,15 +40,15 @@ def test_warp_specialized():
with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes()
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
@@ -56,16 +56,16 @@ def test_warp_specialized():
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
T.call_extern(
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
@T.prim_func
@@ -73,8 +73,8 @@ def test_warp_specialized():
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 256)
A_shared = T.decl_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.decl_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.decl_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
@@ -88,7 +88,7 @@ def test_warp_specialized():
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
@@ -98,7 +98,7 @@ def test_warp_specialized():
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
@@ -110,9 +110,9 @@ def test_warp_specialized():
T.call_extern(
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
......
@@ -4,7 +4,7 @@ import tilelang.testing
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
@@ -38,8 +38,8 @@ def assert_gemm_codegen(
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
# Because the current pass context have been polluted by previous testing.
......
@@ -141,6 +141,7 @@ from . import (
engine, # noqa: F401
tools, # noqa: F401
)
from .language.v2 import dtypes # noqa: F401
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401
......
@@ -6,7 +6,7 @@ from dataclasses import dataclass
import torch
from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var, PrimExpr
from tilelang.utils.tensor import map_torch_type
import tilelang.language as T
@dataclass
@@ -138,7 +138,7 @@ class KernelParam:
>>> param = KernelParam.from_buffer(buffer)
>>> tensor = torch.empty(shape, dtype=param.torch_dtype())
"""
return map_torch_type(str(self.dtype))
return T.dtype(self.dtype).as_torch()
@dataclass
......
@@ -61,9 +61,9 @@ class MatrixCoreIntrinEmitter:
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -105,9 +105,9 @@ class MatrixCoreIntrinEmitter:
self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var
def _initialize_k_dim(self, a_dtype="float16"):
def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz", "int8"]:
if a_dtype in ["float8_e4m3fnuz", T.int8]:
self.k_dim = 32
return
a_dtype = DataType(a_dtype)
@@ -132,7 +132,7 @@ class MatrixCoreIntrinEmitter:
def _initialize_mfma_prefix(self, k_dim=16):
in_dtype, out_dtype = self.a_dtype, self.accum_dtype
M_DIM, N_DIM = self.M_DIM, self.N_DIM
out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype]
out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype]
in_dtype_abbrv = {
"bfloat16": "bf16",
@@ -221,7 +221,7 @@ class MatrixCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32")
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype=T.int32)
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
@@ -521,7 +521,7 @@ class MatrixCoreIntrinEmitter:
self.block_col_warps,
)
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int:
"""
@@ -670,9 +670,9 @@ class MatrixCoreIntrinEmitter:
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
......
@@ -60,9 +60,9 @@ class TensorCoreIntrinEmitter:
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -108,7 +108,7 @@ class TensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
self.k_dim = 256 // a_dtype.bits
@@ -194,9 +194,9 @@ class TensorCoreIntrinEmitter:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32")
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype=T.int32)
else:
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32)
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
@@ -649,7 +649,7 @@ class TensorCoreIntrinEmitter:
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int:
"""
@@ -806,9 +806,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -839,7 +839,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
)
self._initialize_transform_kind(transform_kind_a, transform_kind_b)
def _initialize_k_dim(self, a_dtype="float16"):
def _initialize_k_dim(self, a_dtype=T.float16):
self.k_dim = 256 // DataType(a_dtype).bits
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
@@ -1266,7 +1266,7 @@ class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWith
a_dtype_abbrv = "int4"
b_dtype_abbrv = "int4"
accum_dtype = self.accum_dtype
accum_dtype_abbrv = "int32"
accum_dtype_abbrv = T.int32
mma_prefix = "m16n8k32"
@T.macro
......
@@ -46,9 +46,9 @@ class TensorCoreIntrinEmitter:
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -89,7 +89,7 @@ class TensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
def _initialize_k_dim(self, a_dtype=T.float16):
self.k_dim = 4
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
@@ -147,8 +147,8 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32",
mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == T.float32 else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype=T.int32,
)
if not inverse:
return index_map
@@ -383,7 +383,7 @@ class TensorCoreIntrinEmitter:
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward(i: int, j: int, rep: int) -> int:
"""
......
@@ -133,10 +133,10 @@ class SparseTensorCoreIntrinEmitter:
def __init__(
self,
a_dtype: str = "float16",
e_dtype: str = "uint8",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
e_dtype: str = T.uint8,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
e_transposed: bool = False,
@@ -181,7 +181,7 @@ class SparseTensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
# NOTE: k_dim here represents the logical shape of the MMA operation.
@@ -250,7 +250,7 @@ class SparseTensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32")
index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32)
if not inverse:
return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c])
@@ -708,7 +708,7 @@ class SparseTensorCoreIntrinEmitter:
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int:
"""
......
@@ -73,9 +73,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -245,7 +245,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
)
# Allocate an instruction descriptor wrapper and initialize it
a_dtype_abbrv = self.a_dtype_abbrv
mask_zero = T.Cast("int32", 0)
mask_zero = T.Cast(T.int32, 0)
mask0 = mask1 = mask2 = mask3 = mask_zero
# TCGEN05 only has one warp group
......
@@ -83,9 +83,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def __init__(
self,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_dtype: str = T.float16,
b_dtype: str = T.float16,
accum_dtype: str = T.float16,
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
@@ -515,7 +515,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
self.block_col_warps,
)
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int:
"""
......
@@ -28,6 +28,7 @@ from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
from .v2 import dtypes as _dtypes
from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
@@ -158,7 +159,7 @@ def alloc_barrier(arrive_count: int):
Returns:
T.Buffer: A TVM buffer object allocated as a barrier
"""
return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier")
return T.alloc_buffer([arrive_count], _dtypes.uint64, scope="shared.barrier")
def alloc_tmem(shape, dtype):
@@ -231,7 +232,7 @@ DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"]
def alloc_descriptor(
kind: DescKind = "wgmma",
dtype: str = "uint64",
dtype: str = _dtypes.uint64,
):
"""Allocate a descriptor buffer for WGMMA and TCGEN5.MMA.
@@ -248,28 +249,28 @@ def alloc_descriptor(
return T.alloc_buffer([1], dtype, scope=scope)
def alloc_wgmma_desc(dtype: str = "uint64"):
def alloc_wgmma_desc(dtype: str = _dtypes.uint64):
return alloc_descriptor("wgmma", dtype=dtype)
def alloc_tcgen05_smem_desc(dtype: str = "uint64"):
def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64):
return alloc_descriptor("tcgen05_smem", dtype=dtype)
def alloc_tcgen05_instruction_desc(dtype: str = "uint32"):
def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32):
return alloc_descriptor("tcgen05_instr", dtype=dtype)
# Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32):
return alloc_tcgen05_instruction_desc(dtype)
@overload
def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
def empty(shape: tuple[Unpack[_Shapes]], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
def empty(*shape: Unpack[_Shapes], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
......
@@ -92,7 +92,7 @@ from tvm.script.ir_builder.tir import frame
def buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32",
dtype: str = T.float32,
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
@@ -143,7 +143,7 @@ def buffer(
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides]
strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else:
strides = []
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
@@ -244,7 +244,7 @@ def func_ret(ret_type: Type) -> Type:
def match_buffer(
param: Union[Var, BufferLoad, BufferRegion],
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None,
dtype: str = "float32",
dtype: str = T.float32,
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
@@ -266,11 +266,11 @@ def match_buffer(
-------
Match buffer from function parameter
.. code-block:: python
A = T.match_buffer(a, (128, 128), dtype="float32")
A = T.match_buffer(a, (128, 128), dtype=T.float32)
Match buffer from Buffer subregion
.. code-block:: python
A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype=T.float32)
Parameters
----------
@@ -320,7 +320,7 @@ def match_buffer(
raise ValueError("Shape must be specified when binding input param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32"
idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else T.int32
strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides]
else:
strides = []
@@ -440,7 +440,7 @@ def block_attr(attrs: Dict[str, Any]) -> None:
def alloc_buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32",
dtype: str = T.float32,
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
@@ -491,7 +491,7 @@ def alloc_buffer(
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides]
strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else:
strides = []
return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
@@ -537,7 +537,7 @@ class axis: # pylint: disable=invalid-name
def spatial(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
dtype: str = T.int32,
) -> Var:
"""The spatial block axis defining function.
@@ -565,7 +565,7 @@ class axis: # pylint: disable=invalid-name
def reduce(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
dtype: str = T.int32,
) -> Var:
"""The reduced block axis defining function.
@@ -593,7 +593,7 @@ class axis: # pylint: disable=invalid-name
def scan(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
dtype: str = T.int32,
) -> Var:
"""The scanning block axis defining function.
@@ -621,7 +621,7 @@ class axis: # pylint: disable=invalid-name
def opaque(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr,
dtype: str = "int32",
dtype: str = T.int32,
) -> Var:
"""The opaque block axis defining function.
@@ -646,7 +646,7 @@ class axis: # pylint: disable=invalid-name
)
@staticmethod
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = T.int32) -> Union[List[Var], Var]:
"""The block axis remapping function.
Parameters
@@ -1133,7 +1133,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
def decl_buffer(
shape,
dtype="float32",
dtype=T.float32,
data=None,
strides=None,
elem_offset=None,
@@ -1184,7 +1184,7 @@ def decl_buffer(
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides]
strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else:
strides = []
return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
@@ -1237,7 +1237,7 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
def env_thread(thread_tag: str, dtype: str = T.int32) -> IterVar:
"""Bind a var to thread env
Parameters
@@ -1656,7 +1656,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
args = []
for name, i in zip(params.keys(), identity + identity):
if isinstance(i, int):
args.append(Var(name, "int32"))
args.append(Var(name, T.int32))
else:
args.append(Var(name, i.dtype))
res = combiner(*args)
......
@@ -94,7 +94,7 @@ def _gemm_impl(
offset_a = A_offset[-1]
offset_b = B_offset[-1]
mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, "uint32")
mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, T.uint32)
C_coords = [r.min for r in C_region.region]
# Convert BufferRegion to tl.region calls for arguments
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
......
@@ -157,7 +157,7 @@ class BufferProxy:
def __call__(
self,
shape,
dtype="float32",
dtype=T.float32,
data=None,
strides=None,
elem_offset=None,
......
@@ -89,12 +89,12 @@ def macro(*args, hygienic: bool = True) -> Callable:
@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def use1(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def use2(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
......
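The `use1`/`use2` hunk above illustrates the hygienic-macro distinction between a value captured when the macro is defined and one resolved when it is expanded. A plain-Python closure analogy (illustrative only; `static_capture`/`dynamic_capture` here are standalone stand-ins, not the TVMScript macros):

```python
# "Static capture" freezes the surrounding value at definition time,
# like the hygienic macro producing B[()] = A[128].
# "Dynamic capture" re-reads it on every call, like B[()] = A[x_value].
x_value = 128

def make_static_capture():
    captured = x_value          # value frozen now
    def read(a):
        return a[captured]
    return read

static_capture = make_static_capture()

def dynamic_capture(a):
    return a[x_value]           # resolved at call time

a = list(range(1024))
x_value = 7                     # later rebinding only affects dynamic_capture
print(static_capture(a))        # 128
print(dynamic_capture(a))       # 7
```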
......@@ -1163,7 +1163,7 @@ def ptx_tcgen05_mma_ss(
desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`.
Alternatively, use `variant="ws"` (or "default").
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
- kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
# Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws
......@@ -1224,7 +1224,7 @@ def ptx_tcgen05_mma_ts(
Expects 13 positional arguments:
(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3).
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16,
- kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
"""
return call_intrin(
......
......@@ -13,6 +13,7 @@ from .utils import construct_strides
import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
from collections.abc import Sequence
......
......@@ -11,7 +11,7 @@ _T = TypeVar("_T")
if TYPE_CHECKING:
class dtype(Generic[_T]):
def torch(self) -> torch.dtype: ...
def as_torch(self) -> torch.dtype: ...
else:
dtype = tvm.DataType
......@@ -68,7 +68,32 @@ _TORCH_DTYPE_TO_STR = {
torch.bfloat16: "bfloat16",
}
# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
_extended_torch_dtypes = [
("float8_e4m3fn",),
("float8_e4m3fnuz",),
("float8_e5m2",),
("float8_e5m2fnuz",),
("float8_e8m0fnu",),
("float4_e2m1fnx2",),
]
for dtype_name_tuple in _extended_torch_dtypes:
dtype_name = dtype_name_tuple[0]
torch_dtype = getattr(torch, dtype_name, None)
if torch_dtype is not None:
_TORCH_DTYPE_TO_STR[torch_dtype] = dtype_name
_CANONICAL_TO_DISPLAY_STR = {
"double": "float64",
"float": "float32",
"int": "int32",
"long": "int64",
"short": "int16",
"uint": "uint32",
"ulong": "uint64",
}
_STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()}
......@@ -76,7 +101,9 @@ _DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_T
_STR_TO_TVM_DTYPE_CALL = {
"bool": "Boolean",
"int4": "Int4",
"int8": "Int8",
"int16": "Int16",
"int32": "Int32",
"int64": "Int64",
"uint8": "UInt8",
......@@ -127,12 +154,20 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
return call(expr, is_size_var)
def __dtype_as_torch__(self: dtype) -> torch.dtype:
"""Convert TileLang dtype to PyTorch dtype."""
dtype_str = str(self)
if dtype_str in _STR_TO_TORCH_DTYPE:
return _STR_TO_TORCH_DTYPE[dtype_str]
raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
__orig_dtype_new = dtype.__new__
def __dtype_new__(cls, value: AnyDType) -> dtype:
if isinstance(value, str):
return __orig_dtype_new(cls, value)
return __orig_dtype_new(cls, _CANONICAL_TO_DISPLAY_STR.get(value, value))
elif value in _DTYPE_TO_STR:
return __orig_dtype_new(cls, _DTYPE_TO_STR[value])
else:
......@@ -142,6 +177,7 @@ def __dtype_new__(cls, value: AnyDType) -> dtype:
dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
dtype.as_torch = __dtype_as_torch__
def get_tvm_dtype(value: AnyDType) -> dtype:
......@@ -155,10 +191,12 @@ if TYPE_CHECKING:
class bool(dtype): ...
class short(dtype): ...
class int(dtype): ...
class uint(dtype): ...
class long(dtype): ...
class half(dtype): ...
class float(dtype): ...
class double(dtype): ...
class int4(dtype): ...
class int8(dtype): ...
class int16(dtype): ...
class int32(dtype): ...
......@@ -320,10 +358,12 @@ else:
bool = dtype("bool")
short = dtype("int16")
int = dtype("int32")
uint = dtype("uint32")
long = dtype("int64")
half = dtype("float16")
float = dtype("float32")
double = dtype("float64")
int4 = dtype("int4")
int8 = dtype("int8")
int16 = dtype("int16")
int32 = dtype("int32")
......@@ -484,10 +524,12 @@ _all_dtypes = {
"bool",
"short",
"int",
"uint",
"long",
"half",
"float",
"double",
"int4",
"int8",
"int16",
"int32",
......
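The `_CANONICAL_TO_DISPLAY_STR` table added above lets `__dtype_new__` accept C-style alias strings (`"float"`, `"long"`, ...) and canonicalize them to explicit names before constructing the dtype. A minimal standalone sketch of that lookup (the real path goes through `tvm.DataType`; this only models the string step):

```python
# Alias table mirroring the _CANONICAL_TO_DISPLAY_STR added in this diff.
_CANONICAL_TO_DISPLAY_STR = {
    "double": "float64",
    "float": "float32",
    "int": "int32",
    "long": "int64",
    "short": "int16",
    "uint": "uint32",
    "ulong": "uint64",
}

def canonicalize_dtype(value: str) -> str:
    # Unknown names pass through unchanged, matching the .get(value, value)
    # fallback in __dtype_new__.
    return _CANONICAL_TO_DISPLAY_STR.get(value, value)

print(canonicalize_dtype("float"))    # float32
print(canonicalize_dtype("uint"))     # uint32
print(canonicalize_dtype("float16"))  # float16 (already canonical)
```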
......@@ -31,10 +31,20 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]:
if mma_dtype not in [
T.float16,
T.bfloat16,
T.float32,
T.int8,
T.float8_e4m3,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
]:
raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
if buffer.dtype not in ["uint8", "int8"]:
if buffer.dtype not in [T.uint8, T.int8]:
raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}")
bits_map = {
......@@ -43,7 +53,10 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl
"float32": 32,
"int8": 8,
"float8_e4m3": 8,
"float8_e4m3fn": 8,
"float8_e4m3fnuz": 8,
"float8_e5m2": 8,
"float8_e5m2fnuz": 8,
}
# ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
......@@ -112,10 +125,10 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
buffer: metadata buffer shape, for sm80 it should be a 16bit type
"""
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
if mma_dtype in [T.float16, T.bfloat16] and buffer.dtype not in [T.uint16, T.int16]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
if mma_dtype in ["float8_e4m3", "float8_e5m2", T.int8, T.uint8] and buffer.dtype not in [T.uint32, T.int32]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
m, k = buffer.shape
......@@ -134,7 +147,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
return T.Layout(buffer.shape, ColumnMajorInterleaved)
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args):
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = T.float16, arch: str | None = None, **extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
......
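The sm8x hunk enforces a width pairing between the MMA dtype and its sparse-metadata buffer: 16-bit MMA dtypes require 16-bit metadata, while 8-bit dtypes require 32-bit metadata. A standalone sketch of just that validation (string-based, without the `T.*` dtype objects or buffer plumbing):

```python
# Illustrative mirror of the width check in make_cutlass_metadata_layout_sm8x.
def check_metadata_dtype(mma_dtype: str, meta_dtype: str) -> None:
    if mma_dtype in ("float16", "bfloat16") and meta_dtype not in ("uint16", "int16"):
        raise ValueError(f"metadata should be 16 bit, got {meta_dtype}")
    if (mma_dtype in ("float8_e4m3", "float8_e5m2", "int8", "uint8")
            and meta_dtype not in ("uint32", "int32")):
        raise ValueError(f"metadata should be 32 bit, got {meta_dtype}")

check_metadata_dtype("float16", "uint16")   # OK: 16-bit dtype, 16-bit metadata
check_metadata_dtype("int8", "int32")       # OK: 8-bit dtype, 32-bit metadata
```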
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Literal
from tilelang import language as T
decode_i4_to_f16 = """
template <typename T1, typename T2, bool isSigned = false>
......@@ -1088,10 +1089,10 @@ __device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16
def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8", "int4"],
source_format: Literal["int", "uint"] = "uint",
out_dtype: Literal[T.float16, T.int8, T.int4],
source_format: Literal[T.int, T.uint] = T.uint,
source_bit: int = 4,
storage_dtype: Literal["int32", "int8"] = "int8",
storage_dtype: Literal[T.int32, T.int8] = T.int8,
with_scaling: bool = False,
with_zeros: bool = False,
zeros_mode: Literal["original", "rescale", "quantized"] = "original",
......@@ -1104,10 +1105,10 @@ def get_lop3_intrin_group(
Parameters
----------
in_dtype : Literal["int8"]
in_dtype : Literal[T.int8]
The data type of the input. It should be "int8".
out_dtype : Literal["float16", "int8", "int4"]
out_dtype : Literal[T.float16, T.int8, T.int4]
The data type of the output. It can be either "float16" or "int8" or "int4".
storage_nbit : int, optional
......@@ -1130,18 +1131,17 @@ def get_lop3_intrin_group(
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ."
out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype)
assert out_dtype in [T.float16, T.int8, T.int4], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ."
dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"}
dtype_mapping = {T.float16: "f16", T.int4: "i4", T.int8: "i8", T.int32: "i32"}
target_dtype = dtype_mapping[out_dtype]
if source_format not in ["int", "uint"]:
raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.")
if with_zeros and source_format == "int":
if source_format not in [T.int, T.uint]:
raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}, {type(source_format)}.")
if with_zeros and source_format == T.int:
raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}")
source_symbol = "i" if source_format == "int" else "u"
import_c_map = {
"i4_to_f16": decode_i4_to_f16,
"i2_to_f16": decode_i2_to_f16,
......@@ -1176,15 +1176,15 @@ def get_lop3_intrin_group(
if is_ladder_stage3:
key += "_offset"
if out_dtype == "float16":
if out_dtype == T.float16:
d4f = "f16"
elif out_dtype == "int8":
elif out_dtype == T.int8:
d4f = "i8s"
elif out_dtype == "int4":
elif out_dtype == T.int4:
d4f = "i4s"
else:
raise ValueError(f"Unsupported target dtype: {target_dtype}")
source_symbol = "u" if source_format == "uint" else "s"
source_symbol = "u" if source_format == T.uint else "s"
func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
if with_scaling:
func_name += "_scale"
......
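The `get_lop3_intrin_group` changes leave the decode-function naming scheme intact: the name is assembled from the source bit width, the signedness symbol, and a short suffix for the output dtype. A string-based sketch of that assembly (names illustrative; the real function also selects C source and handles zeros/ladder stages):

```python
# Mirror of the func_name construction in get_lop3_intrin_group.
def lop3_func_name(out_dtype: str, source_format: str, source_bit: int,
                   with_scaling: bool = False) -> str:
    d4f = {"float16": "f16", "int8": "i8s", "int4": "i4s"}[out_dtype]
    source_symbol = "u" if source_format == "uint" else "s"
    name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
    if with_scaling:
        name += "_scale"
    return name

print(lop3_func_name("float16", "uint", 4))                    # decode_i4u_to_f16
print(lop3_func_name("int8", "int", 2, with_scaling=True))     # decode_i2s_to_i8s_scale
```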