Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and DeepSeek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
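A minimal sketch of such a canonical-to-display mapping (the names here are illustrative, not the actual tilelang API; the display strings follow the `dtype_map` used elsewhere in this PR):

```python
# Hypothetical sketch only: maps canonical dtype names to short display strings.
_CANONICAL_TO_DISPLAY = {
    "float16": "f16",
    "bfloat16": "bf16",
    "float32": "f32",
    "int32": "i32",
}

def display_name(dtype: str) -> str:
    """Return the short display string for a canonical dtype name,
    falling back to the name itself for unmapped types."""
    return _CANONICAL_TO_DISPLAY.get(dtype, dtype)
```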

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
from typing import Literal
from tilelang import language as T
# Implementation asm for fp4 to bf16, using twiddling
# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
@@ -49,10 +50,10 @@ __device__ void decode_fp4_to_bf16_twiddling(T1 *B_local, T2 *B_local_decode, co
def get_mxfp_intrin_group(
out_dtype: Literal["float16", "bfloat16"] = "bfloat16",
source_format: Literal["int", "uint"] = "uint",
out_dtype: Literal[T.float16, T.bfloat16] = T.bfloat16,
source_format: Literal[T.int, T.uint] = T.uint,
source_bit: int = 4,
storage_dtype: Literal["int32", "int8", "uint8"] = "uint8",
storage_dtype: Literal[T.int32, T.int8, T.uint8] = T.uint8,
use_twiddling: bool = False,
) -> dict[str, str]:
"""
@@ -65,10 +66,10 @@ def get_mxfp_intrin_group(
`_twiddling`).
Parameters:
out_dtype: Target floating-point type for decoded values; either "float16" or "bfloat16".
out_dtype: Target floating-point type for decoded values; either T.float16 or T.bfloat16.
source_format: Integer source representation; "int" or "uint".
source_bit: Bit width of the packed source format (e.g., 4).
storage_dtype: Underlying storage integer dtype (one of "int32", "int8", "uint8").
storage_dtype: Underlying storage integer dtype (one of T.int32, T.int8, T.uint8).
use_twiddling: When True, select the twiddling variant of the decoding intrinsic.
Returns:
@@ -80,11 +81,12 @@ def get_mxfp_intrin_group(
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
"""
assert out_dtype in ["float16", "bfloat16"], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert source_format in ["int", "uint"], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in ["int32", "int8", "uint8"], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
out_dtype, source_format, storage_dtype = T.dtype(out_dtype), T.dtype(source_format), T.dtype(storage_dtype)
assert out_dtype in [T.float16, T.bfloat16], f"Invalid out_dtype: {out_dtype}. Expected 'float16' or 'bfloat16'."
assert source_format in [T.int, T.uint], f"Invalid source_format: {source_format}. Expected 'int' or 'uint'."
assert storage_dtype in [T.int32, T.int8, T.uint8], f"Invalid storage_dtype: {storage_dtype}. Expected 'int32' or 'int8' or 'uint8'."
dtype_map = {"float16": "f16", "bfloat16": "bf16"}
dtype_map = {T.float16: "f16", T.bfloat16: "bf16"}
key = f"fp{source_bit}_to_{dtype_map[out_dtype]}"
if use_twiddling:
key += "_twiddling"
......
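The key construction in `get_mxfp_intrin_group` can be sketched in plain Python; the `dtype_map` values mirror the diff above:

```python
def mxfp_key(source_bit: int, out_dtype: str, use_twiddling: bool) -> str:
    """Build the lookup key that selects a C decode intrinsic."""
    dtype_map = {"float16": "f16", "bfloat16": "bf16"}
    key = f"fp{source_bit}_to_{dtype_map[out_dtype]}"
    if use_twiddling:
        key += "_twiddling"
    return key
```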
@@ -22,6 +22,7 @@
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""TIR computation utilities for quantization."""
from tilelang import language as T
from tilelang import tvm as tvm
from tvm import tir
@@ -36,7 +37,7 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
a bfloat16 constructed from the unpacked sign, a scaled exponent, and the 1-bit mantissa.
Behavior:
- Validates `nbit == 4`, `dtype == "bfloat16"`, and `val.dtype == "uint8"` (AssertionError if violated).
- Validates `nbit == 4`, `dtype == T.bfloat16`, and `val.dtype == T.uint8` (AssertionError if violated).
- Extracts the 4-bit field at position `pos` (fields are packed consecutively in `val`).
- Interprets the 4-bit field as: sign = bit3, exponent = bits1-2, mantissa = bit0.
- Converts the 2-bit exponent to bf16 exponent space by adding a bias of 126, adds `scale` to that exponent,
@@ -49,27 +50,27 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
- val: uint8 expression containing packed fields.
- pos: index of the field within `val` (0-based); used to compute the bit shift.
- scale: exponent-scale to add to the converted exponent (treated as an unsigned integer expression).
- dtype: must be "bfloat16".
- dtype: must be T.bfloat16.
Returns:
- A tir.PrimExpr of dtype "bfloat16" representing the decoded and scaled value.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
e_bf16 = min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16))
| (m_f4 << tir.const(6, T.uint16))).astype(T.uint16))
return val_bf16
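As a sanity check on the bit layout, the decode above can be mirrored in plain Python (a sketch, not part of the codebase; `struct` widens the bf16 bits to float32 by appending 16 zero bits):

```python
import struct

def u8_fp4_to_bf16(val: int, pos: int, scale: int = 0) -> float:
    """Decode the 4-bit (1 sign / 2 exponent / 1 mantissa) field at `pos`
    of a uint8, apply an exponent scale, and return the bfloat16 value."""
    f4 = (val >> (pos * 4)) & 0xF
    s = f4 >> 3
    e_f4 = (f4 & 6) >> 1
    m = f4 & 1
    # Rebias into bf16 exponent space (+126), add scale, clamp to 8 bits.
    e_bf16 = min(e_f4 + 126 + scale, 255)
    bits16 = (((s << 8) | e_bf16) << 7) | (m << 6)
    # Widen the bf16 bit pattern to float32 by appending 16 zero bits.
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]
```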
def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True):
@@ -88,29 +89,29 @@ def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_eve
Returns:
tir.PrimExpr: A uint32 PrimExpr containing the packed bfloat16 representations (v0 low 16 bits, v1 high 16 bits).
"""
mask = tir.const((1 << 16) - 1, "uint32")
mask = tir.const((1 << 16) - 1, T.uint32)
res = []
for data in [v0, v1]:
u32_val = tir.reinterpret("uint32", data)
u32_val = tir.reinterpret(T.uint32, data)
if round_to_even:
rounding_bias = ((u32_val >> tir.const(16, "uint32"))
& tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32")
rounding_bias = ((u32_val >> tir.const(16, T.uint32))
& tir.const(1, T.uint32)) + tir.const(0x7FFF, T.uint32)
u32_val += rounding_bias
res.append((u32_val >> tir.const(16, "uint32")) & mask)
return res[0] | (res[1] << tir.const(16, "uint32"))
res.append((u32_val >> tir.const(16, T.uint32)) & mask)
return res[0] | (res[1] << tir.const(16, T.uint32))
def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr):
mask = tir.const((1 << 16) - 1, "uint32")
mask = tir.const((1 << 16) - 1, T.uint32)
x0 = x & mask
x1 = (x >> 16) & mask
return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1])
return (tir.reinterpret(T.float32, x << tir.const(16, T.uint32)) for x in [x0, x1])
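The pack/unpack pair above amounts to the following plain-Python sketch (round-to-nearest-even on pack, zero-extension on unpack):

```python
import struct

def f32_to_bf16_bits(x: float, round_to_even: bool = True) -> int:
    """Reduce a float32 to its 16-bit bfloat16 bit pattern."""
    u32 = struct.unpack("<I", struct.pack("<f", x))[0]
    if round_to_even:
        # Add half-ulp plus the LSB of the kept part (round-to-nearest-even).
        u32 = (u32 + ((u32 >> 16) & 1) + 0x7FFF) & 0xFFFFFFFF
    return (u32 >> 16) & 0xFFFF

def pack_bf16x2(v0: float, v1: float) -> int:
    """Pack two values as bfloat16 into one uint32 (v0 low, v1 high)."""
    return f32_to_bf16_bits(v0) | (f32_to_bf16_bits(v1) << 16)

def unpack_bf16x2(u: int) -> tuple:
    """Unpack a uint32 into the two float32 values its bf16 halves encode."""
    lo, hi = u & 0xFFFF, (u >> 16) & 0xFFFF
    as_f32 = lambda b: struct.unpack("<f", struct.pack("<I", b << 16))[0]
    return as_f32(lo), as_f32(hi)
```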
def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == "uint32"
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask)
assert val.dtype == T.uint32
mask = tvm.tir.const((1 << nbit) - 1, T.uint32)
return tir.Cast(dtype, (val >> (pos * nbit).astype(T.uint32)) & mask)
def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
@@ -119,7 +120,7 @@ def _tir_packed_uint_to_uint_to_float(storage_nbit: int):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1)) - 1
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
@@ -130,74 +131,74 @@ def _tir_packed_int_to_int_to_float(storage_nbit: int):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
mask = tir.const((1 << nbit) - 1, T.int32)
unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32))
return f_convert
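The shift-based sign extension in `f_convert` can be checked with a plain-Python sketch that emulates int32 arithmetic shifts (Python integers are unbounded, so int32 wraparound is modeled with masks):

```python
def unpack_signed(nbit: int, val: int, pos: int) -> int:
    """Extract the field of width `nbit` at index `pos` from `val` and
    sign-extend it via left-then-arithmetic-right shifts."""
    mask = (1 << nbit) - 1
    unextended = (val >> (pos * nbit)) & mask
    x = (unextended << (32 - nbit)) & 0xFFFFFFFF  # emulate int32 wraparound
    if x & 0x80000000:
        x -= 1 << 32                              # reinterpret as signed int32
    return x >> (32 - nbit)                       # arithmetic shift back down
```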
def _tir_f32_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float32"
val_u32 = tir.reinterpret("uint32", val)
assert val.dtype == T.float32
val_u32 = tir.reinterpret(T.uint32, val)
# e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7)
# e_f32 == 120 -> e_f4 = 1
# e_f32 < 120 -> e_f4 = 0
m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32")
e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32")
s = (val_u32 >> tir.const(31, "uint32"))
m_h = (val_u32 >> tir.const(22, T.uint32)) & tir.const(1, T.uint32)
e_f32 = (val_u32 >> tir.const(23, T.uint32)) & tir.const(255, T.uint32)
s = (val_u32 >> tir.const(31, T.uint32))
e_f4 = tir.Select(
e_f32 > tir.const(120, "uint32"),
tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"),
tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
e_f32 > tir.const(120, T.uint32),
tir.Min(e_f32 - tir.const(120, T.uint32) + m_h, tir.const(7, T.uint32)),
tir.Select(e_f32 == tir.const(120, T.uint32), tir.const(1, T.uint32),
tir.const(0, T.uint32)))
return (s << tir.const(3, T.uint32)) | e_f4
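A plain-Python mirror of the float32-to-fp4 exponent mapping above (sketch only, for checking the three exponent cases in the comment):

```python
import struct

def f32_to_fp4_bits(x: float) -> int:
    """Encode a float32 sign/exponent into the 4-bit (1 sign / 3 exponent) format."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    m_h = (u >> 22) & 1          # high mantissa bit, used as a rounding carry
    e_f32 = (u >> 23) & 255
    s = u >> 31
    if e_f32 > 120:
        e_f4 = min(e_f32 - 120 + m_h, 7)
    elif e_f32 == 120:
        e_f4 = 1
    else:
        e_f4 = 0
    return (s << 3) | e_f4
```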
def _tir_f16_to_uint_to_f4(val: tir.PrimExpr):
assert val.dtype == "float16"
val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val))
m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32")
e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32")
s = (val_u32 >> tir.const(15, "uint32"))
assert val.dtype == T.float16
val_u32 = tir.Cast(T.uint32, tir.reinterpret(T.uint16, val))
m_h = (val_u32 >> tir.const(9, T.uint32)) & tir.const(1, T.uint32)
e_f16 = (val_u32 >> tir.const(10, T.uint32)) & tir.const(31, T.uint32)
s = (val_u32 >> tir.const(15, T.uint32))
e_f4 = tir.Select(
e_f16 > tir.const(8, "uint32"),
tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")),
tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32")))
return (s << tir.const(3, "uint32")) | e_f4
e_f16 > tir.const(8, T.uint32),
tir.Min(e_f16 - tir.const(8, T.uint32) + m_h, tir.const(7, T.uint32)),
tir.Select(e_f16 == tir.const(8, T.uint32), tir.const(1, T.uint32), tir.const(0, T.uint32)))
return (s << tir.const(3, T.uint32)) | e_f4
def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float32"
assert val.dtype == "uint32"
assert dtype == T.float32
assert val.dtype == T.uint32
# e_f4 == 0 -> e_f32 = 0
# e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint32")
f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask
s = f4 >> tir.const(3, "uint32")
e_f4 = f4 & tir.const(7, "uint32")
e_f32 = e_f4 | tir.const(120, "uint32")
val_f32 = tir.reinterpret("float32",
(e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32"))
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32)
mask = tvm.tir.const((1 << nbit) - 1, T.uint32)
f4 = (val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & mask
s = f4 >> tir.const(3, T.uint32)
e_f4 = f4 & tir.const(7, T.uint32)
e_f32 = e_f4 | tir.const(120, T.uint32)
val_f32 = tir.reinterpret(T.float32,
(e_f32 | (s << tir.const(8, T.uint32))) << tir.const(23, T.uint32))
return tir.Select(e_f4 == tir.const(0, T.uint32), tir.const(0, T.float32), val_f32)
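The fp4-to-float32 decode above can likewise be mirrored in plain Python (sketch only):

```python
import struct

def fp4_field_to_f32(packed: int, pos: int) -> float:
    """Decode the 4-bit field at index `pos` of a packed uint32 to float32."""
    f4 = (packed >> (pos * 4)) & 0xF
    s = f4 >> 3
    e_f4 = f4 & 7
    if e_f4 == 0:
        return 0.0                        # e_f4 == 0 encodes zero
    # e_f4 != 0 -> f32 exponent = e_f4 + 120 = e_f4 | 0b1111000
    bits = ((e_f4 | 120) | (s << 8)) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```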
def _tir_packed_to_fp4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint32"
assert dtype == T.float16
assert val.dtype == T.uint32
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
mask = tvm.tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = f4 & tir.const(7, "uint16")
e_f16 = e_f4 | tir.const(8, "uint16")
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, "uint16"), tir.const(0, "float16"), val_f16)
mask = tvm.tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = f4 & tir.const(7, T.uint16)
e_f16 = e_f4 | tir.const(8, T.uint16)
val_f16 = tir.reinterpret(T.float16,
((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16)).astype(T.uint16))
return tir.Select(e_f4 == tir.const(0, T.uint16), tir.const(0, T.float16), val_f16)
def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)
@@ -210,37 +211,37 @@ def _tir_packed_to_fp4_to_f16(storage_type="uint", storage_nbit=8):
s = f4 >> tir.const(3, storage_dtype)
e_f4 = f4 & tir.const(7, storage_dtype)
e_f16 = e_f4 | tir.const(8, storage_dtype)
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype("uint16"))
return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, "float16"), val_f16)
val_f16 = tir.reinterpret(T.float16,
((e_f16 | (s << tir.const(5, storage_dtype))) << tir.const(10, storage_dtype)).astype(T.uint16))
return tir.Select(e_f4 == tir.const(0, storage_dtype), tir.const(0, T.float16), val_f16)
return f_convert
def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
tir.const(0x4000, "uint16"))
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix
return tir.reinterpret("float16", s_f16 | e_f16)
assert dtype == T.float16
s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16)
e4 = val & tir.const(0x40, T.uint16)
prefix = tir.Select(e4 == tir.const(0, T.uint16), tir.const(0x2000, T.uint16),
tir.const(0x4000, T.uint16))
e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | prefix
return tir.reinterpret(T.float16, s_f16 | e_f16)
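The branch-based e4m3 decode above can be verified against float16 bit patterns in plain Python (`struct`'s `'e'` format packs IEEE half precision):

```python
import struct

def u8_e4m3_to_f16(val: int) -> float:
    """Decode a float8_e4m3 byte into float16 using the branch-based bit trick."""
    s_f16 = (val >> 7) << 15
    e4 = val & 0x40                        # top exponent bit of e4m3
    # Exponent rebias: prepend 0b01 when e4 is set, 0b00(1) otherwise.
    prefix = 0x2000 if e4 == 0 else 0x4000
    e_f16 = ((val & 63) << 7) | prefix
    return struct.unpack("<e", struct.pack("<H", s_f16 | e_f16))[0]
```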
def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
return tir.reinterpret("float16", s_f16 | e_f16)
assert dtype == T.float16
s_f16 = (val >> tir.const(7, T.uint16)) << tir.const(15, T.uint16)
e4 = val & tir.const(0x40, T.uint16)
e_f16 = ((val & tir.const(63, T.uint16)) << tir.const(7, T.uint16)) | (e4 << tir.const(8, T.uint16)) | (e4 << tir.const(7, T.uint16))
e_f16 = e_f16 ^ tir.const(0x2000, T.uint16)
return tir.reinterpret(T.float16, s_f16 | e_f16)
def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("float8_e5m2", val).astype("float16")
assert dtype == T.float16
return tir.reinterpret("float8_e5m2", val).astype(T.float16)
def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
@@ -249,7 +250,7 @@ def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
max_int_value = (1 << (nbit - 1))
return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const(
return ((val >> (pos.astype(T.uint32) * tir.const(nbit, T.uint32))) & tir.const(
(1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype)
return f_convert
@@ -283,10 +284,10 @@ def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}"
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
mask = tir.const((1 << nbit) - 1, T.int32)
unextended = (val >> (pos.astype(T.int32) * tir.const(nbit, T.int32))) & mask
return tir.Cast(
dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
dtype, (unextended << tir.const(32 - nbit, T.int32)) >> tir.const(32 - nbit, T.int32))
return f_convert
......
@@ -2,6 +2,7 @@ from dataclasses import dataclass
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.utils.language import is_shared, is_fragment
from tilelang.tileop.base import GemmWarpPolicy
from tvm.ir.base import Node
@@ -121,7 +122,7 @@ class GemmBase:
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, T.uint32))
@property
def mbar(self) -> tir.Buffer:
@@ -131,7 +132,7 @@ class GemmBase:
def C_coords(self):
coords = getattr(self.gemm_node, "cCoords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
zero = tvm.tir.const(0, T.int32)
return [zero, zero]
return [coords[i] for i in range(len(coords))]
......
@@ -98,7 +98,7 @@ class GemmTCGEN5(GemmBase):
raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access")
accum_dtype = str(self.C.dtype)
if accum_dtype not in ["float32", "float16"]:
if accum_dtype not in [str(T.float32), str(T.float16)]:
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.ARegion
......
@@ -100,10 +100,10 @@ class PassConfigKey(str, Enum):
such as `dst[i] = f(src[i])`, avoiding implicit aliasing:
```
read = T.allocate([1], "int32", "local.var")
write = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
write_buf = T.Buffer((1,), "int32", data=write, scope="local.var")
read = T.allocate([1], T.int32, "local.var")
write = T.allocate([1], T.int32, "local.var")
read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var")
write_buf = T.Buffer((1,), T.int32, data=write, scope="local.var")
write_buf[0] = read_buf[0] * 2
f(write_buf[0])
```
@@ -113,8 +113,8 @@ class PassConfigKey(str, Enum):
like:
```
read = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
read = T.allocate([1], T.int32, "local.var")
read_buf = T.Buffer((1,), T.int32, data=read, scope="local.var")
read_buf[0] = read_buf[0] * 2
f(read_buf[0])
```
......