Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
from typing import Literal from typing import Literal
from tilelang import language as T
# Implementation asm for fp4 to bf16, using twiddling # Implementation asm for fp4 to bf16, using twiddling
# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18 # Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
...@@ -49,10 +50,10 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co ...@@ -49,10 +50,10 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
def get_mxfp_intrin_group( def get_mxfp_intrin_group(
out_dtype: Literal["float16", "bfloat16"] = "bfloat16", out_dtype: Literal[T.float16, T.bfloat16] = T.bfloat16,
source_format: Literal["int", "uint"] = "uint", source_format: Literal[T.int, T.uint] = T.uint,
source_bit: int = 4, source_bit: int = 4,
storage_dtype: Literal["int32", "int8", "uint8"] = "uint8", storage_dtype: Literal[T.int32, T.int8, T.uint8] = T.uint8,
use_twiddling: bool = False, use_twiddling: bool = False,
) -> dict[str, str]: ) -> dict[str, str]:
""" """
...@@ -65,10 +66,10 @@ def get_mxfp_intrin_group( ...@@ -65,10 +66,10 @@ def get_mxfp_intrin_group(
`_twiddling`). `_twiddling`).
Parameters: Parameters:
out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16". out_dtype: Target floating-point type for decoded values; either T.float16 or T.bfloat16.
source_format: Integer source representation; "int" or "uint". source_format: Integer source representation; "int" or "uint".
source_bit: Bit width of the packed source format (e.g., 4). source_bit: Bit width of the packed source format (e.g., 4).
storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8"). storage_dtype: Underlying storage integer dtype (one of T.int32, T.int8, T.uint8).
use_twiddling: When True, select the twiddling variant of the decoding intrinsic. use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
Returns: Returns:
...@@ -80,11 +81,12 @@ def get_mxfp_intrin_group( ...@@ -80,11 +81,12 @@ def get_mxfp_intrin_group(
AssertionError: if out_dtype, source_format, or storage_dtype are not supported. AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation. KeyError: if the constructed key does not match any available C source implementation.
""" """
assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'." out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype)
assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'." assert out_dtype in [T.float16, T.bfloat16], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'." assert source_format in [T.int, T.uint], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in [T.int32, T.int8, T.uint8], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
dtype_map = {"float16": "f16", "bfloat16": "bf16"} dtype_map = {T.float16: "f16", T.bfloat16: "bf16"}
key = f"fp{source_bit}_to_{dtype_map[out_dtype]}" key = f"fp{source_bit}_to_{dtype_map[out_dtype]}"
if use_twiddling: if use_twiddling:
key += "_twiddling" key += "_twiddling"
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
# pylint: disable=invalid-name,missing-function-docstring,unused-variable # pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""TIR computation utilities for quantization.""" """TIR computation utilities for quantization."""
from tilelang import language as T
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import tir from tvm import tir
...@@ -36,7 +37,7 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale ...@@ -36,7 +37,7 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa. a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.
Behavior: Behavior:
- Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated). - Validates `nbit == 4`, `dtype == T.bfloat16`, and `val.dtype == T.uint8` (AssertionError if violated).
- Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`). - Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
- Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0. - Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0.
- Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent, - Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent,
...@@ -49,27 +50,27 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale ...@@ -49,27 +50,27 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
- val: uint8 expression containing packed fields. - val: uint8 expression containing packed fields.
- pos: index of the field within `val` (0-based); used to compute the bit shift. - pos: index of the field within `val` (0-based); used to compute the bit shift.
- scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression). - scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
- dtype: must be "bfloat16". - dtype: must be T.bfloat16.
Returns: Returns:
- A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value. - A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
""" """
assert nbit == 4 assert nbit == 4
assert dtype == "bfloat16" assert dtype == T.bfloat16
assert val.dtype == "uint8" assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, "uint16") mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, "uint16") s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16") e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8 # Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits # To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
m_f4 = f4 & tir.const(1, "uint16") m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret("bfloat16", val_bf16 = tir.reinterpret(T.bfloat16,
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) ((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16")) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16))
return val_bf16 return val_bf16
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
...@@ -88,29 +89,29 @@ def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_eve ...@@ -88,29 +89,29 @@ def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_eve
Returns: Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits). tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
""" """
mask = tir.const((1 << 16) - 1, "uint32") mask = tir.const((1 << 16) - 1, T.uint32)
res = [] res = []
for data in [v0, v1]: for data in [v0, v1]:
u32_val = tir.reinterpret("uint32", data) u32_val = tir.reinterpret(T.uint32, data)
if round_to_even: if round_to_even:
rounding_bias = ((u32_val >> tir.const(16, "uint32")) rounding_bias = ((u32_val >> tir.const(16, T.uint32))
& tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") & tir.const(1, T.uint32)) + tir.const(0x7FFF, T.uint32)
u32_val += rounding_bias u32_val += rounding_bias
res.append((u32_val >> tir.const(16, "uint32")) & mask) res.append((u32_val >> tir.const(16, T.uint32)) & mask)
return res[0] | (res[1] << tir.const(16, "uint32")) return res[0] | (res[1] << tir.const(16, T.uint32))
def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr):
mask = tir.const((1 << 16) - 1, "uint32") mask = tir.const((1 << 16) - 1, T.uint32)
x0 = x & mask x0 = x & mask
x1 = (x >> 16) & mask x1 = (x >> 16) & mask
return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) return (tir.reinterpret(T.float32, x << tir.const(16, T.uint32)) for x in [x0, x1])
def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == "uint32" assert val.dtype == T.uint32
mask = tvm.tir.const((1 << nbit) - 1, "uint32") mask = tvm.tir.const((1 << nbit) - 1, T.uint32)
return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) return tir.Cast(dtype, (val >> (pos * nbit).astype(T.uint32)) & mask)
def _tir_packed_uint_to_uint_to_float(storage_nbit: int): def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
...@@ -119,7 +120,7 @@ def _tir_packed_uint_to_uint_to_float(storage_nbit: int): ...@@ -119,7 +120,7 @@ def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1)) - 1 max_int_value = (1 << (nbit - 1)) - 1
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert return f_convert
...@@ -130,74 +131,74 @@ def _tir_packed_int_to_int_to_float(storage_nbit: int): ...@@ -130,74 +131,74 @@ def _tir_packed_int_to_int_to_float(storage_nbit: int):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32") mask = tir.const((1 << nbit) - 1, T.int32)
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask
return tir.Cast( return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32))
return f_convert return f_convert
def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): def _tir_f32_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float32" assert val.dtype == T.float32
val_u32 = tir.reinterpret("uint32", val) val_u32 = tir.reinterpret(T.uint32, val)
# e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7)
# e_f32 == 120 -> e_f4 = 1 # e_f32 == 120 -> e_f4 = 1
# e_f32 < 120 -> e_f4 = 0 # e_f32 < 120 -> e_f4 = 0
m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") m_h = (val_u32 >> tir.const(22, T.uint32)) & tir.const(1, T.uint32)
e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") e_f32 = (val_u32 >> tir.const(23, T.uint32)) & tir.const(255, T.uint32)
s = (val_u32 >> tir.const(31, "uint32")) s = (val_u32 >> tir.const(31, T.uint32))
e_f4 = tir.Select( e_f4 = tir.Select(
e_f32 > tir.const(120, "uint32"), e_f32 > tir.const(120, T.uint32),
tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), tir.Min(e_f32 - tir.const(120, T.uint32) + m_h, tir.const(7, T.uint32)),
tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), tir.Select(e_f32 == tir.const(120, T.uint32), tir.const(1, T.uint32),
tir.const(0, "uint32"))) tir.const(0, T.uint32)))
return (s << tir.const(3, "uint32")) | e_f4 return (s << tir.const(3, T.uint32)) | e_f4
def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): def _tir_f16_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float16" assert val.dtype == T.float16
val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) val_u32 = tir.Cast(T.uint32, tir.reinterpret(T.uint16, val))
m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") m_h = (val_u32 >> tir.const(9, T.uint32)) & tir.const(1, T.uint32)
e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") e_f16 = (val_u32 >> tir.const(10, T.uint32)) & tir.const(31, T.uint32)
s = (val_u32 >> tir.const(15, "uint32")) s = (val_u32 >> tir.const(15, T.uint32))
e_f4 = tir.Select( e_f4 = tir.Select(
e_f16 > tir.const(8, "uint32"), e_f16 > tir.const(8, T.uint32),
tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), tir.Min(e_f16 - tir.const(8, T.uint32) + m_h, tir.const(7, T.uint32)),
tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) tir.Select(e_f16 == tir.const(8, T.uint32), tir.const(1, T.uint32), tir.const(0, T.uint32)))
return (s << tir.const(3, "uint32")) | e_f4 return (s << tir.const(3, T.uint32)) | e_f4
def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4 assert nbit == 4
assert dtype == "float32" assert dtype == T.float32
assert val.dtype == "uint32" assert val.dtype == T.uint32
# e_f4 == 0 -> e_f32 = 0 # e_f4 == 0 -> e_f32 = 0
# e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint32") mask = tvm.tir.const((1 << nbit) - 1, T.uint32)
f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask f4 = (val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & mask
s = f4 >> tir.const(3, "uint32") s = f4 >> tir.const(3, T.uint32)
e_f4 = f4 & tir.const(7, "uint32") e_f4 = f4 & tir.const(7, T.uint32)
e_f32 = e_f4 | tir.const(120, "uint32") e_f32 = e_f4 | tir.const(120, T.uint32)
val_f32 = tir.reinterpret("float32", val_f32 = tir.reinterpret(T.float32,
(e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) (e_f32 | (s << tir.const(8, T.uint32))) << tir.const(23, T.uint32))
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) return tir.Select(e_f4 == tir.const(0, T.uint32), tir.const(0, T.float32), val_f32)
def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4 assert nbit == 4
assert dtype == "float16" assert dtype == T.float16
assert val.dtype == "uint32" assert val.dtype == T.uint32
# e_f4 == 0 -> e_f16 = 0 # e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint16") mask = tvm.tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, "uint16") s = f4 >> tir.const(3, T.uint16)
e_f4 = f4 & tir.const(7, "uint16") e_f4 = f4 & tir.const(7, T.uint16)
e_f16 = e_f4 | tir.const(8, "uint16") e_f16 = e_f4 | tir.const(8, T.uint16)
val_f16 = tir.reinterpret("float16", val_f16 = tir.reinterpret(T.float16,
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16")) ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16)).astype(T.uint16))
return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16) return tir.Select(e_f4 == tir.const(0, T.uint16), tir.const(0, T.float16), val_f16)
def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit) storage_dtype = storage_type + str(storage_nbit)
...@@ -210,37 +211,37 @@ def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8): ...@@ -210,37 +211,37 @@ def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
s = f4 >> tir.const(3, storage_dtype) s = f4 >> tir.const(3, storage_dtype)
e_f4 = f4 & tir.const(7, storage_dtype) e_f4 = f4 & tir.const(7, storage_dtype)
e_f16 = e_f4 | tir.const(8, storage_dtype) e_f16 = e_f4 | tir.const(8, storage_dtype)
val_f16 = tir.reinterpret("float16", val_f16 = tir.reinterpret(T.float16,
((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16")) ((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype(T.uint16))
return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16) return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, T.float16), val_f16)
return f_convert return f_convert
def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8 assert nbit == 8
assert dtype == "float16" assert dtype == T.float16
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16)
e4 = val & tir.const(0x40, "uint16") e4 = val & tir.const(0x40, T.uint16)
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), prefix = tir.Select(e4 == tir.const(0, T.uint16), tir.const(0x2000, T.uint16),
tir.const(0x4000, "uint16")) tir.const(0x4000, T.uint16))
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | prefix
return tir.reinterpret("float16", s_f16 | e_f16) return tir.reinterpret(T.float16, s_f16 | e_f16)
def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8 assert nbit == 8
assert dtype == "float16" assert dtype == T.float16
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16)
e4 = val & tir.const(0x40, "uint16") e4 = val & tir.const(0x40, T.uint16)
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | (e4 << tir.const(8, T.uint16)) | (e4 << tir.const(7, T.uint16))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16") e_f16 = e_f16 ^ tir.const(0x2000, T.uint16)
return tir.reinterpret("float16", s_f16 | e_f16) return tir.reinterpret(T.float16, s_f16 | e_f16)
def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8 assert nbit == 8
assert dtype == "float16" assert dtype == T.float16
return tir.reinterpret("float8_e5m2", val).astype("float16") return tir.reinterpret("float8_e5m2", val).astype(T.float16)
def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
...@@ -249,7 +250,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): ...@@ -249,7 +250,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1)) max_int_value = (1 << (nbit - 1))
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert return f_convert
...@@ -283,10 +284,10 @@ def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): ...@@ -283,10 +284,10 @@ def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32") mask = tir.const((1 << nbit) - 1, T.int32)
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask
return tir.Cast( return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32))
return f_convert return f_convert
......
...@@ -2,6 +2,7 @@ from dataclasses import dataclass ...@@ -2,6 +2,7 @@ from dataclasses import dataclass
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.target import Target from tvm.target import Target
from tvm import tir from tvm import tir
from tilelang import language as T
from tilelang.utils.language import is_shared, is_fragment from tilelang.utils.language import is_shared, is_fragment
from tilelang.tileop.base import GemmWarpPolicy from tilelang.tileop.base import GemmWarpPolicy
from tvm.ir.base import Node from tvm.ir.base import Node
...@@ -121,7 +122,7 @@ class GemmBase: ...@@ -121,7 +122,7 @@ class GemmBase:
@property @property
def mbarptr(self) -> PrimExpr: def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32")) return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, T.uint32))
@property @property
def mbar(self) -> tir.Buffer: def mbar(self) -> tir.Buffer:
...@@ -131,7 +132,7 @@ class GemmBase: ...@@ -131,7 +132,7 @@ class GemmBase:
def C_coords(self): def C_coords(self):
coords = getattr(self.gemm_node, "cCoords", None) coords = getattr(self.gemm_node, "cCoords", None)
if coords is None or len(coords) == 0: if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32") zero = tvm.tir.const(0, T.int32)
return [zero, zero] return [zero, zero]
return [coords[i] for i in range(len(coords))] return [coords[i] for i in range(len(coords))]
......
...@@ -98,7 +98,7 @@ class GemmTCGEN5(GemmBase): ...@@ -98,7 +98,7 @@ class GemmTCGEN5(GemmBase):
raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")
accum_dtype = str(self.C.dtype) accum_dtype = str(self.C.dtype)
if accum_dtype not in ["float32", "float16"]: if accum_dtype not in [str(T.float32), str(T.float16)]:
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.ARegion A_shared = self.ARegion
......
...@@ -100,10 +100,10 @@ class PassConfigKey(str, Enum): ...@@ -100,10 +100,10 @@ class PassConfigKey(str, Enum):
such as `dst[i] = f(src[i])`, avoiding implicit aliasing: such as `dst[i] = f(src[i])`, avoiding implicit aliasing:
``` ```
read = T.allocate([1], "int32", "local.var") read = T.allocate([1], T.int32, "local.var")
write = T.allocate([1], "int32", "local.var") write = T.allocate([1], T.int32, "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var") read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var")
write_buf = T.Buffer((1,), "int32", data=write, scope="local.var") write_buf = T.Buffer((1,), T.int32, data=write, scope="local.var")
write_buf[0] = read_buf[0] * 2 write_buf[0] = read_buf[0] * 2
f(write_buf[0]) f(write_buf[0])
``` ```
...@@ -113,8 +113,8 @@ class PassConfigKey(str, Enum): ...@@ -113,8 +113,8 @@ class PassConfigKey(str, Enum):
like: like:
``` ```
read = T.allocate([1], "int32", "local.var") read = T.allocate([1], T.int32, "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var") read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var")
read_buf[0] = read_buf[0] * 2 read_buf[0] = read_buf[0] * 2
f(read_buf[0]) f(read_buf[0])
``` ```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment