Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
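
As a general illustration of why a float32 accumulator is paired with float16 inputs, here is a minimal sketch using only Python's standard library. It is not code from this commit: `round_to` and `accumulate` are hypothetical helper names, and rounding through `struct` only emulates reduced-precision storage on the CPU.

```python
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to the precision of a struct format.

    'e' is IEEE 754 binary16 (float16), 'f' is binary32 (float32).
    """
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, fmt: str) -> float:
    """Sum `values`, rounding the running total to `fmt` after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + round_to(fmt, v))
    return acc


ones = [1.0] * 3000
fp16_sum = accumulate(ones, "e")  # float16 accumulator
fp32_sum = accumulate(ones, "f")  # float32 accumulator

# float16 has an 11-bit significand, so 2048 + 1 rounds back to 2048
# and the running total stops growing; float32 keeps the exact count.
print(fp16_sum, fp32_sum)  # 2048.0 3000.0
```

The float16 total stalls at 2048 while the float32 total reaches 3000, which is the kind of error a wider accumulator avoids.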

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
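
The mapping idea above can be sketched as follows. This is a hypothetical, self-contained illustration, not the table or function added by this commit: the names `_ALIAS_TO_DISPLAY` and `display_name` and the specific aliases are assumptions.

```python
# Hypothetical alias table; the real mapping in the dtypes module may differ.
_ALIAS_TO_DISPLAY = {
    "float": "float32",
    "double": "float64",
    "half": "float16",
    "int": "int32",
}


def display_name(name: str) -> str:
    """Return a display string for a dtype name given as a string.

    Known aliases are rewritten to an explicit-width name; anything else
    is passed through unchanged (after trimming and lower-casing).
    """
    key = name.strip().lower()
    return _ALIAS_TO_DISPLAY.get(key, key)


print(display_name("float"))    # float32
print(display_name("float16"))  # float16
```

Routing string inputs through one table like this keeps "float" and "float32" from being treated as two different types further down the pipeline.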

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
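
A rough sketch of the pattern behind a `dtype.as_torch()` style conversion is shown below. `SketchDType`, `torch_attr`, and the lookup table are hypothetical stand-ins written for this illustration; they are not tilelang's classes, and PyTorch is imported lazily so the sketch loads without it.

```python
from dataclasses import dataclass

# Hypothetical table from dtype names to attribute names on the torch module.
_TORCH_ATTR = {
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float32": "float32",
    "int8": "int8",
    "int32": "int32",
    "bool": "bool",
}


@dataclass(frozen=True)
class SketchDType:
    """Stand-in for a dtype object that knows its PyTorch counterpart."""

    name: str

    def torch_attr(self) -> str:
        """Return the torch attribute name for this dtype, e.g. 'float16'."""
        try:
            return _TORCH_ATTR[self.name]
        except KeyError:
            raise ValueError(f"no torch mapping for dtype {self.name!r}") from None

    def as_torch(self):
        """Return the torch dtype object; requires PyTorch to be installed."""
        import torch  # lazy import: only needed when the conversion is used

        return getattr(torch, self.torch_attr())


fp16 = SketchDType("float16")
print(fp16.torch_attr())  # float16
```

With a method on the dtype object, call sites read `dtype.as_torch()` instead of passing `str(dtype)` through a separate helper, which is the shape of the `KernelParam.torch_dtype` change in the diff below.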

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
...@@ -25,7 +25,7 @@ def _check(original, transformed): ...@@ -25,7 +25,7 @@ def _check(original, transformed):
M = 512 M = 512
N = 512 N = 512
K = 512 K = 512
dtype = "float16" dtype = T.float16
block_M = 64 block_M = 64
block_N = 64 block_N = 64
block_K = 32 block_K = 32
...@@ -40,15 +40,15 @@ def test_warp_specialized(): ...@@ -40,15 +40,15 @@ def test_warp_specialized():
with T.block(""): with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes() T.writes()
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local") C_local = T.alloc_buffer((32,), scope="local")
for k in T.serial(16, annotations={"num_stages": T.int32(3)}): for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0: if v == 0:
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, k * 32,
by * 64, by * 64,
) )
...@@ -56,16 +56,16 @@ def test_warp_specialized(): ...@@ -56,16 +56,16 @@ def test_warp_specialized():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, bx * 64,
k * 32, k * 32,
) )
T.call_extern( T.call_extern(
"handle", "handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
@T.prim_func @T.prim_func
...@@ -73,8 +73,8 @@ def test_warp_specialized(): ...@@ -73,8 +73,8 @@ def test_warp_specialized():
bx = T.launch_thread("blockIdx.x", 8) bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8) by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 256) v = T.launch_thread("threadIdx.x", 256)
A_shared = T.decl_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") A_shared = T.decl_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") B_shared = T.decl_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local") C_local = T.decl_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0) T.attr([128, 128], "kWarpSpecializationScope", 0)
...@@ -88,7 +88,7 @@ def test_warp_specialized(): ...@@ -88,7 +88,7 @@ def test_warp_specialized():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3), T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, k * 32,
by * 64, by * 64,
) )
...@@ -98,7 +98,7 @@ def test_warp_specialized(): ...@@ -98,7 +98,7 @@ def test_warp_specialized():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
T.get_mbarrier(k % 3), T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, bx * 64,
k * 32, k * 32,
) )
...@@ -110,9 +110,9 @@ def test_warp_specialized(): ...@@ -110,9 +110,9 @@ def test_warp_specialized():
T.call_extern( T.call_extern(
"handle", "handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
......
...@@ -4,7 +4,7 @@ import tilelang.testing ...@@ -4,7 +4,7 @@ import tilelang.testing
import tilelang.language as T import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
...@@ -38,8 +38,8 @@ def assert_gemm_codegen( ...@@ -38,8 +38,8 @@ def assert_gemm_codegen(
block_M, block_M,
block_N, block_N,
block_K, block_K,
dtype="float16", dtype=T.float16,
accum_dtype="float", accum_dtype=T.float32,
): ):
func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
# Because the current pass context have been polluted by previous testing. # Because the current pass context have been polluted by previous testing.
......
...@@ -141,6 +141,7 @@ from . import ( ...@@ -141,6 +141,7 @@ from . import (
engine, # noqa: F401 engine, # noqa: F401
tools, # noqa: F401 tools, # noqa: F401
) )
from .language.v2 import dtypes # noqa: F401
from .autotuner import autotune # noqa: F401 from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401 from .transform import PassConfigKey # noqa: F401
......
...@@ -6,7 +6,7 @@ from dataclasses import dataclass ...@@ -6,7 +6,7 @@ from dataclasses import dataclass
import torch import torch
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var, PrimExpr from tvm.tir import Buffer, IntImm, Var, PrimExpr
from tilelang.utils.tensor import map_torch_type import tilelang.language as T
@dataclass @dataclass
...@@ -138,7 +138,7 @@ class KernelParam: ...@@ -138,7 +138,7 @@ class KernelParam:
>>> param = KernelParam.from_buffer(buffer) >>> param = KernelParam.from_buffer(buffer)
>>> tensor = torch.empty(shape, dtype=param.torch_dtype()) >>> tensor = torch.empty(shape, dtype=param.torch_dtype())
""" """
return map_torch_type(str(self.dtype)) return T.dtype(self.dtype).as_torch()
@dataclass @dataclass
......
...@@ -61,9 +61,9 @@ class MatrixCoreIntrinEmitter: ...@@ -61,9 +61,9 @@ class MatrixCoreIntrinEmitter:
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -105,9 +105,9 @@ class MatrixCoreIntrinEmitter: ...@@ -105,9 +105,9 @@ class MatrixCoreIntrinEmitter:
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
self.thread_var = thread_var self.thread_var = thread_var
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz", "int8"]: if a_dtype in ["float8_e4m3fnuz", T.int8]:
self.k_dim = 32 self.k_dim = 32
return return
a_dtype = DataType(a_dtype) a_dtype = DataType(a_dtype)
...@@ -132,7 +132,7 @@ class MatrixCoreIntrinEmitter: ...@@ -132,7 +132,7 @@ class MatrixCoreIntrinEmitter:
def _initialize_mfma_prefix(self, k_dim=16): def _initialize_mfma_prefix(self, k_dim=16):
in_dtype, out_dtype = self.a_dtype, self.accum_dtype in_dtype, out_dtype = self.a_dtype, self.accum_dtype
M_DIM, N_DIM = self.M_DIM, self.N_DIM M_DIM, N_DIM = self.M_DIM, self.N_DIM
out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype]
in_dtype_abbrv = { in_dtype_abbrv = {
"bfloat16": "bf16", "bfloat16": "bf16",
...@@ -221,7 +221,7 @@ class MatrixCoreIntrinEmitter: ...@@ -221,7 +221,7 @@ class MatrixCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mfma_store_index_map, index_dtype="int32") index_map = IndexMap.from_func(mfma_store_index_map, index_dtype=T.int32)
if not inverse: if not inverse:
return index_map return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
...@@ -521,7 +521,7 @@ class MatrixCoreIntrinEmitter: ...@@ -521,7 +521,7 @@ class MatrixCoreIntrinEmitter:
self.block_col_warps, self.block_col_warps,
) )
inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mfma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
...@@ -670,9 +670,9 @@ class MatrixCoreIntrinEmitter: ...@@ -670,9 +670,9 @@ class MatrixCoreIntrinEmitter:
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter): class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
......
...@@ -60,9 +60,9 @@ class TensorCoreIntrinEmitter: ...@@ -60,9 +60,9 @@ class TensorCoreIntrinEmitter:
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -108,7 +108,7 @@ class TensorCoreIntrinEmitter: ...@@ -108,7 +108,7 @@ class TensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
) )
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype) a_dtype = DataType(a_dtype)
self.k_dim = 256 // a_dtype.bits self.k_dim = 256 // a_dtype.bits
...@@ -194,9 +194,9 @@ class TensorCoreIntrinEmitter: ...@@ -194,9 +194,9 @@ class TensorCoreIntrinEmitter:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
if DataType(self.accum_dtype).bits == 64: if DataType(self.accum_dtype).bits == 64:
index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype="int32") index_map = IndexMap.from_func(mma_store_index_map_fp64, index_dtype=T.int32)
else: else:
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32)
if not inverse: if not inverse:
return index_map return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
...@@ -649,7 +649,7 @@ class TensorCoreIntrinEmitter: ...@@ -649,7 +649,7 @@ class TensorCoreIntrinEmitter:
self.block_col_warps, self.block_col_warps,
) )
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
...@@ -806,9 +806,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -806,9 +806,9 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -839,7 +839,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): ...@@ -839,7 +839,7 @@ class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
) )
self._initialize_transform_kind(transform_kind_a, transform_kind_b) self._initialize_transform_kind(transform_kind_a, transform_kind_b)
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
self.k_dim = 256 // DataType(a_dtype).bits self.k_dim = 256 // DataType(a_dtype).bits
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
...@@ -1266,7 +1266,7 @@ class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWith ...@@ -1266,7 +1266,7 @@ class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWith
a_dtype_abbrv = "int4" a_dtype_abbrv = "int4"
b_dtype_abbrv = "int4" b_dtype_abbrv = "int4"
accum_dtype = self.accum_dtype accum_dtype = self.accum_dtype
accum_dtype_abbrv = "int32" accum_dtype_abbrv = T.int32
mma_prefix = "m16n8k32" mma_prefix = "m16n8k32"
@T.macro @T.macro
......
...@@ -46,9 +46,9 @@ class TensorCoreIntrinEmitter: ...@@ -46,9 +46,9 @@ class TensorCoreIntrinEmitter:
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -89,7 +89,7 @@ class TensorCoreIntrinEmitter: ...@@ -89,7 +89,7 @@ class TensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
) )
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
self.k_dim = 4 self.k_dim = 4
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16): def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16):
...@@ -147,8 +147,8 @@ class TensorCoreIntrinEmitter: ...@@ -147,8 +147,8 @@ class TensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func( index_map = IndexMap.from_func(
mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == "float32" else mma_32x8_to_shared_16x16_layout_fp16, mma_32x8_to_shared_16x16_layout_fp32 if self.accum_dtype == T.float32 else mma_32x8_to_shared_16x16_layout_fp16,
index_dtype="int32", index_dtype=T.int32,
) )
if not inverse: if not inverse:
return index_map return index_map
...@@ -383,7 +383,7 @@ class TensorCoreIntrinEmitter: ...@@ -383,7 +383,7 @@ class TensorCoreIntrinEmitter:
self.block_col_warps, self.block_col_warps,
) )
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward(i: int, j: int, rep: int) -> int: def forward(i: int, j: int, rep: int) -> int:
""" """
......
...@@ -133,10 +133,10 @@ class SparseTensorCoreIntrinEmitter: ...@@ -133,10 +133,10 @@ class SparseTensorCoreIntrinEmitter:
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
e_dtype: str = "uint8", e_dtype: str = T.uint8,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
e_transposed: bool = False, e_transposed: bool = False,
...@@ -181,7 +181,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -181,7 +181,7 @@ class SparseTensorCoreIntrinEmitter:
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}" f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
) )
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype=T.float16):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype) a_dtype = DataType(a_dtype)
# NOTE: k_dim here represents the logical shape of the MMA operation. # NOTE: k_dim here represents the logical shape of the MMA operation.
...@@ -250,7 +250,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -250,7 +250,7 @@ class SparseTensorCoreIntrinEmitter:
def get_store_index_map(self, inverse: bool = False) -> IndexMap: def get_store_index_map(self, inverse: bool = False) -> IndexMap:
warp_size, local_size_c = self.WARP_SIZE, self.local_size_out warp_size, local_size_c = self.WARP_SIZE, self.local_size_out
index_map = IndexMap.from_func(mma_store_index_map, index_dtype="int32") index_map = IndexMap.from_func(mma_store_index_map, index_dtype=T.int32)
if not inverse: if not inverse:
return index_map return index_map
inverse_index_map = index_map.inverse([warp_size, local_size_c]) inverse_index_map = index_map.inverse([warp_size, local_size_c])
...@@ -708,7 +708,7 @@ class SparseTensorCoreIntrinEmitter: ...@@ -708,7 +708,7 @@ class SparseTensorCoreIntrinEmitter:
self.block_col_warps, self.block_col_warps,
) )
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
......
...@@ -73,9 +73,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -73,9 +73,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -245,7 +245,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -245,7 +245,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
) )
# Allocate an instruction descriptor wrapper and initialize it # Allocate an instruction descriptor wrapper and initialize it
a_dtype_abbrv = self.a_dtype_abbrv a_dtype_abbrv = self.a_dtype_abbrv
mask_zero = T.Cast("int32", 0) mask_zero = T.Cast(T.int32, 0)
mask0 = mask1 = mask2 = mask3 = mask_zero mask0 = mask1 = mask2 = mask3 = mask_zero
# TCGEN05 only has one warp group # TCGEN05 only has one warp group
......
...@@ -83,9 +83,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -83,9 +83,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
def __init__( def __init__(
self, self,
a_dtype: str = "float16", a_dtype: str = T.float16,
b_dtype: str = "float16", b_dtype: str = T.float16,
accum_dtype: str = "float16", accum_dtype: str = T.float16,
a_transposed: bool = False, a_transposed: bool = False,
b_transposed: bool = False, b_transposed: bool = False,
block_row_warps: int = 2, block_row_warps: int = 2,
...@@ -515,7 +515,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): ...@@ -515,7 +515,7 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
self.block_col_warps, self.block_col_warps,
) )
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
......
...@@ -28,6 +28,7 @@ from tvm.tir import PrimExpr ...@@ -28,6 +28,7 @@ from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm from tvm.tir.expr import FloatImm, IntImm
from .v2 import dtypes as _dtypes
from .v2.dtypes import dtype as tl_dtype from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
...@@ -158,7 +159,7 @@ def alloc_barrier(arrive_count: int): ...@@ -158,7 +159,7 @@ def alloc_barrier(arrive_count: int):
Returns: Returns:
T.Buffer: A TVM buffer object allocated as a barrier T.Buffer: A TVM buffer object allocated as a barrier
""" """
return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier") return T.alloc_buffer([arrive_count], _dtypes.uint64, scope="shared.barrier")
def alloc_tmem(shape, dtype): def alloc_tmem(shape, dtype):
...@@ -231,7 +232,7 @@ DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"] ...@@ -231,7 +232,7 @@ DescKind = Literal["wgmma", "tcgen05_smem", "tcgen05_instr"]
def alloc_descriptor( def alloc_descriptor(
kind: DescKind = "wgmma", kind: DescKind = "wgmma",
dtype: str = "uint64", dtype: str = _dtypes.uint64,
): ):
"""Allocate a descriptor buffer for WGMMA and TCGEN5.MMA. """Allocate a descriptor buffer for WGMMA and TCGEN5.MMA.
...@@ -248,28 +249,28 @@ def alloc_descriptor( ...@@ -248,28 +249,28 @@ def alloc_descriptor(
return T.alloc_buffer([1], dtype, scope=scope) return T.alloc_buffer([1], dtype, scope=scope)
def alloc_wgmma_desc(dtype: str = "uint64"): def alloc_wgmma_desc(dtype: str = _dtypes.uint64):
return alloc_descriptor("wgmma", dtype=dtype) return alloc_descriptor("wgmma", dtype=dtype)
def alloc_tcgen05_smem_desc(dtype: str = "uint64"): def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64):
return alloc_descriptor("tcgen05_smem", dtype=dtype) return alloc_descriptor("tcgen05_smem", dtype=dtype)
def alloc_tcgen05_instruction_desc(dtype: str = "uint32"): def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32):
return alloc_descriptor("tcgen05_instr", dtype=dtype) return alloc_descriptor("tcgen05_instr", dtype=dtype)
# Alias: short name consistent with imports # Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"): def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32):
return alloc_tcgen05_instruction_desc(dtype) return alloc_tcgen05_instruction_desc(dtype)
@overload @overload
def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ... def empty(shape: tuple[Unpack[_Shapes]], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...
def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: def empty(*shape: Unpack[_Shapes], dtype: str = _dtypes.float32) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)): if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype) return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
......
...@@ -92,7 +92,7 @@ from tvm.script.ir_builder.tir import frame ...@@ -92,7 +92,7 @@ from tvm.script.ir_builder.tir import frame
def buffer( def buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32", dtype: str = T.float32,
data: Var = None, data: Var = None,
strides: List[PrimExpr] = None, strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None, elem_offset: PrimExpr = None,
...@@ -143,7 +143,7 @@ def buffer( ...@@ -143,7 +143,7 @@ def buffer(
""" """
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None: if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else: else:
strides = [] strides = []
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint: disable=no-member
...@@ -244,7 +244,7 @@ def func_ret(ret_type: Type) -> Type: ...@@ -244,7 +244,7 @@ def func_ret(ret_type: Type) -> Type:
def match_buffer( def match_buffer(
param: Union[Var, BufferLoad, BufferRegion], param: Union[Var, BufferLoad, BufferRegion],
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None,
dtype: str = "float32", dtype: str = T.float32,
data: Var = None, data: Var = None,
strides: List[PrimExpr] = None, strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None, elem_offset: PrimExpr = None,
...@@ -266,11 +266,11 @@ def match_buffer( ...@@ -266,11 +266,11 @@ def match_buffer(
------- -------
Match buffer from function parameter Match buffer from function parameter
.. code-block:: python .. code-block:: python
A = T.match_buffer(a, (128, 128), dtype="float32") A = T.match_buffer(a, (128, 128), dtype=T.float32)
Match buffer from Buffer subregion Match buffer from Buffer subregion
.. code-block:: python .. code-block:: python
A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype=T.float32)
Parameters Parameters
---------- ----------
...@@ -320,7 +320,7 @@ def match_buffer( ...@@ -320,7 +320,7 @@ def match_buffer(
raise ValueError("Shape must be specified when binding input param") raise ValueError("Shape must be specified when binding input param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None: if strides is not None:
idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32" idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else T.int32
strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides] strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides]
else: else:
strides = [] strides = []
...@@ -440,7 +440,7 @@ def block_attr(attrs: Dict[str, Any]) -> None: ...@@ -440,7 +440,7 @@ def block_attr(attrs: Dict[str, Any]) -> None:
def alloc_buffer( def alloc_buffer(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
dtype: str = "float32", dtype: str = T.float32,
data: Var = None, data: Var = None,
strides: List[PrimExpr] = None, strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None, elem_offset: PrimExpr = None,
...@@ -491,7 +491,7 @@ def alloc_buffer( ...@@ -491,7 +491,7 @@ def alloc_buffer(
""" """
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None: if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else: else:
strides = [] strides = []
return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
...@@ -537,7 +537,7 @@ class axis: # pylint: disable=invalid-name ...@@ -537,7 +537,7 @@ class axis: # pylint: disable=invalid-name
def spatial( def spatial(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr, binding: PrimExpr,
dtype: str = "int32", dtype: str = T.int32,
) -> Var: ) -> Var:
"""The spatial block axis defining function. """The spatial block axis defining function.
...@@ -565,7 +565,7 @@ class axis: # pylint: disable=invalid-name ...@@ -565,7 +565,7 @@ class axis: # pylint: disable=invalid-name
def reduce( def reduce(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr, binding: PrimExpr,
dtype: str = "int32", dtype: str = T.int32,
) -> Var: ) -> Var:
"""The reduced block axis defining function. """The reduced block axis defining function.
...@@ -593,7 +593,7 @@ class axis: # pylint: disable=invalid-name ...@@ -593,7 +593,7 @@ class axis: # pylint: disable=invalid-name
def scan( def scan(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr, binding: PrimExpr,
dtype: str = "int32", dtype: str = T.int32,
) -> Var: ) -> Var:
"""The scanning block axis defining function. """The scanning block axis defining function.
...@@ -621,7 +621,7 @@ class axis: # pylint: disable=invalid-name ...@@ -621,7 +621,7 @@ class axis: # pylint: disable=invalid-name
def opaque( def opaque(
dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
binding: PrimExpr, binding: PrimExpr,
dtype: str = "int32", dtype: str = T.int32,
) -> Var: ) -> Var:
"""The opaque block axis defining function. """The opaque block axis defining function.
...@@ -646,7 +646,7 @@ class axis: # pylint: disable=invalid-name ...@@ -646,7 +646,7 @@ class axis: # pylint: disable=invalid-name
) )
@staticmethod @staticmethod
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]: def remap(kinds: str, bindings: List[PrimExpr], dtype: str = T.int32) -> Union[List[Var], Var]:
"""The block axis remapping function. """The block axis remapping function.
Parameters Parameters
...@@ -1133,7 +1133,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name ...@@ -1133,7 +1133,7 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
def decl_buffer( def decl_buffer(
shape, shape,
dtype="float32", dtype=T.float32,
data=None, data=None,
strides=None, strides=None,
elem_offset=None, elem_offset=None,
...@@ -1184,7 +1184,7 @@ def decl_buffer( ...@@ -1184,7 +1184,7 @@ def decl_buffer(
""" """
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None: if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides] strides = [Var(s, T.int32) if isinstance(s, str) else s for s in strides]
else: else:
strides = [] strides = []
return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
...@@ -1237,7 +1237,7 @@ def launch_thread( ...@@ -1237,7 +1237,7 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar: def env_thread(thread_tag: str, dtype: str = T.int32) -> IterVar:
"""Bind a var to thread env """Bind a var to thread env
Parameters Parameters
...@@ -1656,7 +1656,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: ...@@ -1656,7 +1656,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
args = [] args = []
for name, i in zip(params.keys(), identity + identity): for name, i in zip(params.keys(), identity + identity):
if isinstance(i, int): if isinstance(i, int):
args.append(Var(name, "int32")) args.append(Var(name, T.int32))
else: else:
args.append(Var(name, i.dtype)) args.append(Var(name, i.dtype))
res = combiner(*args) res = combiner(*args)
......
...@@ -94,7 +94,7 @@ def _gemm_impl( ...@@ -94,7 +94,7 @@ def _gemm_impl(
offset_a = A_offset[-1] offset_a = A_offset[-1]
offset_b = B_offset[-1] offset_b = B_offset[-1]
mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, "uint32") mbar = to_buffer_region(mbar, access_type="rw") if mbar is not None else tir.const(0, T.uint32)
C_coords = [r.min for r in C_region.region] C_coords = [r.min for r in C_region.region]
# Convert BufferRegion to tl.region calls for arguments # Convert BufferRegion to tl.region calls for arguments
A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape]) A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
......
...@@ -157,7 +157,7 @@ class BufferProxy: ...@@ -157,7 +157,7 @@ class BufferProxy:
def __call__( def __call__(
self, self,
shape, shape,
dtype="float32", dtype=T.float32,
data=None, data=None,
strides=None, strides=None,
elem_offset=None, elem_offset=None,
......
...@@ -89,12 +89,12 @@ def macro(*args, hygienic: bool = True) -> Callable: ...@@ -89,12 +89,12 @@ def macro(*args, hygienic: bool = True) -> Callable:
@T.prim_func @T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: def use1(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None:
for x_value in T.serial(10): for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128] static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func @T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: def use2(A: T.Buffer((1024,), T.int32), B: T.Buffer((), T.int32)) -> None:
for x_value in T.serial(10): for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value] dynamic_capture(A, B) ### Produces B[()] = A[x_value]
``` ```
......
...@@ -1163,7 +1163,7 @@ def ptx_tcgen05_mma_ss( ...@@ -1163,7 +1163,7 @@ def ptx_tcgen05_mma_ss(
desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]). desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`. Aliases: you can also pass `ws` or `warp_specialized` (booleans) instead of `enable_ws`.
Alternatively, use `variant="ws"` (or "default"). Alternatively, use `variant="ws"` (or "default").
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16, - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
""" """
# Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws # Aliases precedence: if either `ws` or `warp_specialized` is provided, they override enable_ws
...@@ -1224,7 +1224,7 @@ def ptx_tcgen05_mma_ts( ...@@ -1224,7 +1224,7 @@ def ptx_tcgen05_mma_ts(
Expects 13 positional arguments: Expects 13 positional arguments:
(kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset, (kind_dtype, A_ptr, A_offset, desc_b, B_offset, C_ptr, C_offset,
desc_val, scale_out, mask0, mask1, mask2, mask3). desc_val, scale_out, mask0, mask1, mask2, mask3).
- kind_dtype: instruction kind selector (e.g., "float16" for kind::f16, - kind_dtype: instruction kind selector (e.g., T.float16 for kind::f16,
"tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4). "tf32" for kind::tf32, "int8" for kind::i8, "float8_e4m3" for kind::f8f6f4).
""" """
return call_intrin( return call_intrin(
......
...@@ -13,6 +13,7 @@ from .utils import construct_strides ...@@ -13,6 +13,7 @@ from .utils import construct_strides
import tvm import tvm
from tvm.tir import Buffer from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
from collections.abc import Sequence from collections.abc import Sequence
......
...@@ -11,7 +11,7 @@ _T = TypeVar("_T") ...@@ -11,7 +11,7 @@ _T = TypeVar("_T")
if TYPE_CHECKING: if TYPE_CHECKING:
class dtype(Generic[_T]): class dtype(Generic[_T]):
def torch(self) -> torch.dtype: ... def as_torch(self) -> torch.dtype: ...
else: else:
dtype = tvm.DataType dtype = tvm.DataType
...@@ -68,7 +68,32 @@ _TORCH_DTYPE_TO_STR = { ...@@ -68,7 +68,32 @@ _TORCH_DTYPE_TO_STR = {
torch.bfloat16: "bfloat16", torch.bfloat16: "bfloat16",
} }
# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()} _extended_torch_dtypes = [
("float8_e4m3fn",),
("float8_e4m3fnuz",),
("float8_e5m2",),
("float8_e5m2fnuz",),
("float8_e8m0fnu",),
("float4_e2m1fnx2",),
]
for dtype_name_tuple in _extended_torch_dtypes:
dtype_name = dtype_name_tuple[0]
torch_dtype = getattr(torch, dtype_name, None)
if torch_dtype is not None:
_TORCH_DTYPE_TO_STR[torch_dtype] = dtype_name
_CANONICAL_TO_DISPLAY_STR = {
"double": "float64",
"float": "float32",
"int": "int32",
"long": "int64",
"short": "int16",
"uint": "uint32",
"ulong": "uint64",
}
_STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()} # _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()}
...@@ -76,7 +101,9 @@ _DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_T ...@@ -76,7 +101,9 @@ _DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_T
_STR_TO_TVM_DTYPE_CALL = { _STR_TO_TVM_DTYPE_CALL = {
"bool": "Boolean", "bool": "Boolean",
"int4": "Int4",
"int8": "Int8", "int8": "Int8",
"int16": "Int16",
"int32": "Int32", "int32": "Int32",
"int64": "Int64", "int64": "Int64",
"uint8": "UInt8", "uint8": "UInt8",
...@@ -127,12 +154,20 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var ...@@ -127,12 +154,20 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
return call(expr, is_size_var) return call(expr, is_size_var)
def __dtype_as_torch__(self: dtype) -> torch.dtype:
"""Convert TileLang dtype to PyTorch dtype."""
dtype_str = str(self)
if dtype_str in _STR_TO_TORCH_DTYPE:
return _STR_TO_TORCH_DTYPE[dtype_str]
raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
__orig_dtype_new = dtype.__new__ __orig_dtype_new = dtype.__new__
def __dtype_new__(cls, value: AnyDType) -> dtype: def __dtype_new__(cls, value: AnyDType) -> dtype:
if isinstance(value, str): if isinstance(value, str):
return __orig_dtype_new(cls, value) return __orig_dtype_new(cls, _CANONICAL_TO_DISPLAY_STR.get(value, value))
elif value in _DTYPE_TO_STR: elif value in _DTYPE_TO_STR:
return __orig_dtype_new(cls, _DTYPE_TO_STR[value]) return __orig_dtype_new(cls, _DTYPE_TO_STR[value])
else: else:
...@@ -142,6 +177,7 @@ def __dtype_new__(cls, value: AnyDType) -> dtype: ...@@ -142,6 +177,7 @@ def __dtype_new__(cls, value: AnyDType) -> dtype:
dtype.__call__ = __dtype_call__ dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__ dtype.__new__ = __dtype_new__
dtype.as_torch = __dtype_as_torch__
def get_tvm_dtype(value: AnyDType) -> dtype: def get_tvm_dtype(value: AnyDType) -> dtype:
...@@ -155,10 +191,12 @@ if TYPE_CHECKING: ...@@ -155,10 +191,12 @@ if TYPE_CHECKING:
class bool(dtype): ... class bool(dtype): ...
class short(dtype): ... class short(dtype): ...
class int(dtype): ... class int(dtype): ...
class uint(dtype): ...
class long(dtype): ... class long(dtype): ...
class half(dtype): ... class half(dtype): ...
class float(dtype): ... class float(dtype): ...
class double(dtype): ... class double(dtype): ...
class int4(dtype): ...
class int8(dtype): ... class int8(dtype): ...
class int16(dtype): ... class int16(dtype): ...
class int32(dtype): ... class int32(dtype): ...
...@@ -320,10 +358,12 @@ else: ...@@ -320,10 +358,12 @@ else:
bool = dtype("bool") bool = dtype("bool")
short = dtype("int16") short = dtype("int16")
int = dtype("int32") int = dtype("int32")
uint = dtype("uint32")
long = dtype("int64") long = dtype("int64")
half = dtype("float16") half = dtype("float16")
float = dtype("float32") float = dtype("float32")
double = dtype("float64") double = dtype("float64")
int4 = dtype("int4")
int8 = dtype("int8") int8 = dtype("int8")
int16 = dtype("int16") int16 = dtype("int16")
int32 = dtype("int32") int32 = dtype("int32")
...@@ -484,10 +524,12 @@ _all_dtypes = { ...@@ -484,10 +524,12 @@ _all_dtypes = {
"bool", "bool",
"short", "short",
"int", "int",
"uint",
"long", "long",
"half", "half",
"float", "float",
"double", "double",
"int4",
"int8", "int8",
"int16", "int16",
"int32", "int32",
......
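Reviewer note: the dtype changes above can be sketched without TVM or PyTorch installed. This is a minimal illustration, not the project's implementation. It shows the alias normalization done through `_CANONICAL_TO_DISPLAY_STR`, the `getattr` probe for optional torch dtypes, and the reverse lookup behind `as_torch`. The names `normalize_dtype_name`, `as_backend_dtype`, and `fake_torch` are hypothetical, and `fake_torch` is only a stand-in for the real `torch` module. Unlike the diff, where normalization happens in `__dtype_new__`, the sketch normalizes inside the lookup to stay self-contained.

```python
from types import SimpleNamespace

# Alias table copied from the diff: C-style names map to explicit-width names.
CANONICAL_TO_DISPLAY = {
    "double": "float64",
    "float": "float32",
    "int": "int32",
    "long": "int64",
    "short": "int16",
    "uint": "uint32",
    "ulong": "uint64",
}


def normalize_dtype_name(name: str) -> str:
    """Return the explicit-width spelling; unknown names pass through unchanged."""
    return CANONICAL_TO_DISPLAY.get(name, name)


# Hypothetical stand-in for torch: it exposes only one of the optional dtypes,
# the way an older PyTorch build might lack some float8 variants.
fake_torch = SimpleNamespace(float32="<f32>", float8_e5m2="<f8e5m2>")

backend_to_str = {fake_torch.float32: "float32"}

# Register optional dtypes only when the backend actually provides them.
for optional_name in ("float8_e4m3fn", "float8_e5m2"):
    backend_dtype = getattr(fake_torch, optional_name, None)
    if backend_dtype is not None:
        backend_to_str[backend_dtype] = optional_name

# Reverse table used for the string -> backend dtype direction.
str_to_backend = {name: backend for backend, name in backend_to_str.items()}


def as_backend_dtype(name: str):
    """Look up the backend dtype for a name, failing loudly if it is unsupported."""
    key = normalize_dtype_name(name)
    if key in str_to_backend:
        return str_to_backend[key]
    raise ValueError(
        f"Cannot convert dtype '{key}'. Supported dtypes: {list(str_to_backend)}"
    )
```

With this stand-in, `as_backend_dtype("float")` resolves through the alias table to the `float32` entry, while `as_backend_dtype("float8_e4m3fn")` raises `ValueError` because the stand-in does not define that attribute.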
...@@ -31,10 +31,20 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl ...@@ -31,10 +31,20 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl
block_k = 128 block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2) warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8_e4m3", "float8_e5m2"]: if mma_dtype not in [
T.float16,
T.bfloat16,
T.float32,
T.int8,
T.float8_e4m3,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
]:
raise NotImplementedError(f"Unsupported dtype: {mma_dtype}") raise NotImplementedError(f"Unsupported dtype: {mma_dtype}")
if buffer.dtype not in ["uint8", "int8"]: if buffer.dtype not in [T.uint8, T.int8]:
raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}") raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}")
bits_map = { bits_map = {
...@@ -43,7 +53,10 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl ...@@ -43,7 +53,10 @@ def make_cutlass_metadata_layout_sm90(buffer: tvm.tir.Buffer, mma_dtype: str, bl
"float32": 32, "float32": 32,
"int8": 8, "int8": 8,
"float8_e4m3": 8, "float8_e4m3": 8,
"float8_e4m3fn": 8,
"float8_e4m3fnuz": 8,
"float8_e5m2": 8, "float8_e5m2": 8,
"float8_e5m2fnuz": 8,
} }
# ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117 # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117
...@@ -112,10 +125,10 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): ...@@ -112,10 +125,10 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
buffer: metadata buffer shape, for sm80 it should be a 16bit type buffer: metadata buffer shape, for sm80 it should be a 16bit type
""" """
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]: if mma_dtype in [T.float16, T.bfloat16] and buffer.dtype not in [T.uint16, T.int16]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}") raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]: if mma_dtype in ["float8_e4m3", "float8_e5m2", T.int8, T.uint8] and buffer.dtype not in [T.uint32, T.int32]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}") raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
m, k = buffer.shape m, k = buffer.shape
...@@ -134,7 +147,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str): ...@@ -134,7 +147,7 @@ def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str):
return T.Layout(buffer.shape, ColumnMajorInterleaved) return T.Layout(buffer.shape, ColumnMajorInterleaved)
def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", arch: str | None = None, **extra_args): def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = T.float16, arch: str | None = None, **extra_args):
if arch is None: if arch is None:
arch = nvcc.get_target_compute_version() arch = nvcc.get_target_compute_version()
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
from typing import Literal from typing import Literal
from tilelang import language as T
decode_i4_to_f16 = """ decode_i4_to_f16 = """
template <typename T1, typename T2, bool isSigned = false> template <typename T1, typename T2, bool isSigned = false>
...@@ -1088,10 +1089,10 @@ __device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16 ...@@ -1088,10 +1089,10 @@ __device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16
def get_lop3_intrin_group( def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8", "int4"], out_dtype: Literal[T.float16, T.int8, T.int4],
source_format: Literal["int", "uint"] = "uint", source_format: Literal[T.int, T.uint] = T.uint,
source_bit: int = 4, source_bit: int = 4,
storage_dtype: Literal["int32", "int8"] = "int8", storage_dtype: Literal[T.int32, T.int8] = T.int8,
with_scaling: bool = False, with_scaling: bool = False,
with_zeros: bool = False, with_zeros: bool = False,
zeros_mode: Literal["original", "rescale", "quantized"] = "original", zeros_mode: Literal["original", "rescale", "quantized"] = "original",
...@@ -1104,10 +1105,10 @@ def get_lop3_intrin_group( ...@@ -1104,10 +1105,10 @@ def get_lop3_intrin_group(
Parameters Parameters
---------- ----------
in_dtype : Literal["int8"] in_dtype : Literal[T.int8]
The data type of the input. It should be "int8". The data type of the input. It should be "int8".
out_dtype : Literal["float16", "int8", "int4"] out_dtype : Literal[T.float16, T.int8, T.int4]
The data type of the output. It can be either "float16" or "int8" or "int4". The data type of the output. It can be either "float16" or "int8" or "int4".
storage_nbit : int, optional storage_nbit : int, optional
...@@ -1130,18 +1131,17 @@ def get_lop3_intrin_group( ...@@ -1130,18 +1131,17 @@ def get_lop3_intrin_group(
Dict[str, str] Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations. A dictionary mapping the names of the intrinsics to their corresponding implementations.
""" """
assert out_dtype in ["float16", "int8", "int4"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ." out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype)
assert out_dtype in [T.float16, T.int8, T.int4], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'int8' or 'int4' ."
dtype_mapping = {"float16": "f16", "int4": "i4", "int8": "i8", "int32": "i32"} dtype_mapping = {T.float16: "f16", T.int4: "i4", T.int8: "i8", T.int32: "i32"}
target_dtype = dtype_mapping[out_dtype] target_dtype = dtype_mapping[out_dtype]
if source_format not in ["int", "uint"]: if source_format not in [T.int, T.uint]:
raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}.") raise ValueError(f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}, {type(source_format)}.")
if with_zeros and source_format == "int": if with_zeros and source_format == T.int:
raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}") raise ValueError(f"Zeros are not supported for signed integers, but got {source_format}")
source_symbol = "i" if source_format == "int" else "u"
import_c_map = { import_c_map = {
"i4_to_f16": decode_i4_to_f16, "i4_to_f16": decode_i4_to_f16,
"i2_to_f16": decode_i2_to_f16, "i2_to_f16": decode_i2_to_f16,
...@@ -1176,15 +1176,15 @@ def get_lop3_intrin_group( ...@@ -1176,15 +1176,15 @@ def get_lop3_intrin_group(
if is_ladder_stage3: if is_ladder_stage3:
key += "_offset" key += "_offset"
if out_dtype == "float16": if out_dtype == T.float16:
d4f = "f16" d4f = "f16"
elif out_dtype == "int8": elif out_dtype == T.int8:
d4f = "i8s" d4f = "i8s"
elif out_dtype == "int4": elif out_dtype == T.int4:
d4f = "i4s" d4f = "i4s"
else: else:
raise ValueError(f"Unsupported target dtype: {target_dtype}") raise ValueError(f"Unsupported target dtype: {target_dtype}")
source_symbol = "u" if source_format == "uint" else "s" source_symbol = "u" if source_format == T.uint else "s"
func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}" func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
if with_scaling: if with_scaling:
func_name += "_scale" func_name += "_scale"
......
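Reviewer note: the naming logic visible in the `get_lop3_intrin_group` hunks can be sketched in isolation. This covers only the lines shown in the diff. The elided parts of the function (zeros handling, ladder offsets, the import map) are left out. The helper name `lop3_decode_name` is hypothetical, and it compares plain strings rather than `T.dtype` objects, so it does not reproduce how the real code normalizes `"uint"` before comparing.

```python
def lop3_decode_name(
    out_dtype: str,
    source_format: str = "uint",
    source_bit: int = 4,
    with_scaling: bool = False,
) -> str:
    """Build a decode function name following the pattern shown in the diff."""
    # Output suffixes taken from the if/elif chain in the diff.
    suffix_by_dtype = {"float16": "f16", "int8": "i8s", "int4": "i4s"}
    if out_dtype not in suffix_by_dtype:
        raise ValueError(f"Unsupported target dtype: {out_dtype}")
    if source_format not in ("int", "uint"):
        raise ValueError(
            f"Invalid source_format. Expected 'int' or 'uint', but got {source_format}."
        )

    # Unsigned sources are tagged "u", signed sources "s".
    source_symbol = "u" if source_format == "uint" else "s"
    name = f"decode_i{source_bit}{source_symbol}_to_{suffix_by_dtype[out_dtype]}"
    if with_scaling:
        name += "_scale"
    return name
```

For example, `lop3_decode_name("float16")` gives `decode_i4u_to_f16`, and `lop3_decode_name("int8", "int", 2, True)` gives `decode_i2s_to_i8s_scale`.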