Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Replaced string data type representations in the Hint class with T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
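The canonical-to-display mapping described above can be sketched as follows (a minimal illustration — the names and entries are hypothetical; the real table lives in tilelang's dtype module):

```python
# Hypothetical sketch of a canonical-name mapping: legacy string aliases
# are normalized to one display string per dtype.
_CANONICAL_TO_DISPLAY = {
    "float32": "float32",
    "float": "float32",   # legacy alias normalized to its canonical form
    "float16": "float16",
    "int32": "int32",
}

def canonicalize(name: str) -> str:
    """Return the canonical display string for a dtype name."""
    try:
        return _CANONICAL_TO_DISPLAY[name]
    except KeyError:
        raise ValueError(f"unrecognized dtype string: {name!r}")
```

With such a table, dtype creation from string inputs can accept either spelling while storing a single canonical form.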

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the data type conversion call from the mistyped dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
......@@ -479,9 +479,9 @@ def tilelang_sparse_attention(
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(block_T, tilelang.math.next_power_of_2(dim))
......@@ -876,7 +876,7 @@ if __name__ == "__main__":
parser.add_argument("--dim", type=int, default=128, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=32, help="Block size")
parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)")
parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)")
parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
......@@ -901,7 +901,7 @@ if __name__ == "__main__":
if args.suite:
run_benchmark_suite(impl=args.impl)
else:
dtype = torch.float16 if args.dtype == "float16" else torch.float32
dtype = torch.float16 if args.dtype == T.float16 else torch.float32
if args.impl in ["tilelang", "all"]:
print("Benchmarking TileLang implementation:")
......
......@@ -49,9 +49,9 @@ def tilelang_kernel_fwd(
o_slc_shape = [batch, seq_len, heads, dim]
lse_slc_shape = [batch, seq_len, heads]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
......@@ -170,8 +170,8 @@ def tilelang_kernel_bwd_dkv(
block_size=64,
groups=1,
selected_blocks=16,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
......@@ -217,7 +217,7 @@ def tilelang_kernel_bwd_dkv(
DO_slc: T.Tensor(do_slc_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, "int32"),
BlockMask: T.Tensor(block_mask_shape, T.int32),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
......@@ -340,8 +340,8 @@ def tilelang_kernel_bwd_dqkv(
block_size=64,
groups=1,
selected_blocks=16,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
......@@ -388,7 +388,7 @@ def tilelang_kernel_bwd_dqkv(
DQ: T.Tensor(dq_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, "int32"),
BlockMask: T.Tensor(block_mask_shape, T.int32),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
......@@ -505,8 +505,8 @@ def tilelang_kernel_preprocess(
heads,
seq_len,
dim,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
blk=32,
):
from tilelang import language as T
......@@ -548,7 +548,7 @@ def tilelang_kernel_block_mask(
seq_len,
selected_blocks,
block_size,
dtype="int32",
dtype=T.int32,
):
from tilelang import language as T
......
......@@ -35,9 +35,9 @@ def native_sparse_attention(
q_shape = [batch, 1, heads, dim] # Changed seq_len to 1
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, 1, head_kv, selected_blocks] # Changed seq_len to 1
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
......
......@@ -26,9 +26,9 @@ def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, b
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
block_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
......
......@@ -38,12 +38,12 @@ def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scal
block_counts_shape = [c_seq_len, head_kv]
offsets_shape = [batch + 1]
token_indices_shape = [c_seq_len, 2]
block_indices_dtype = "int32"
block_counts_dtype = "int32"
offsets_dtype = "int32"
token_indices_dtype = "int32"
dtype = "float16"
accum_dtype = "float"
block_indices_dtype = T.int32
block_counts_dtype = T.int32
offsets_dtype = T.int32
token_indices_dtype = T.int32
dtype = T.float16
accum_dtype = T.float32
block_S = block_size
block_T = min(128, tilelang.math.next_power_of_2(dim))
......
......@@ -97,9 +97,9 @@ def mqa_attn_return_logits(
):
if block_Q is None:
block_Q = 128 // heads
dtype = "float8_e4m3"
accum_dtype = "float"
index_dtype = "int32"
dtype = T.float8_e4m3fn
accum_dtype = T.float32
index_dtype = T.int32
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
......@@ -178,8 +178,8 @@ def clean_logits_(
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
dtype = "float"
indices_dtype = "int32"
dtype = T.float
indices_dtype = T.int32
@T.prim_func
def clean_logits_kernel(
......
......@@ -11,21 +11,21 @@ pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
FP8 = "float8_e4m3"
BF16 = "bfloat16"
FP32 = "float32"
FP8 = T.float8_e4m3fn
BF16 = T.bfloat16
FP32 = T.float32
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
bits_x = T.reinterpret(T.uint32, x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
return T.reinterpret(T.float32, bits_x)
def fast_round_scale(amax, fp8_max_inv):
......@@ -107,8 +107,8 @@ def act_quant(x: torch.Tensor,
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32):
assert out_dtype in [BF16, T.float32]
M = T.dynamic("M")
group_size = 128
......
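The fast_log2_ceil / fast_pow2 bit tricks in the hunk above can be checked in plain Python (a host-side sketch, not the TileLang implementation; struct stands in for T.reinterpret):

```python
import struct

def fast_log2_ceil(x: float) -> int:
    # Reinterpret the float32 bit pattern, mirroring T.reinterpret(T.uint32, x).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = (bits >> 23) & 0xFF      # biased exponent field
    man = bits & ((1 << 23) - 1)   # mantissa bits
    # exp - 127 is floor(log2(x)); bump by one when the mantissa is nonzero.
    return exp - 127 + (1 if man != 0 else 0)

def fast_pow2(n: int) -> float:
    # Build the float32 bit pattern for 2**n directly in the exponent field.
    return struct.unpack("<f", struct.pack("<I", (n + 127) << 23))[0]
```

Exact powers of two hit the floor branch (mantissa zero); anything in between rounds up, which is what the scale computation in act_quant relies on.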
......@@ -13,11 +13,11 @@ def preprocess(
D,
block_ND=32,
num_stages=5,
dtype="bfloat16",
accum_dtype="float",
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert dtype == T.bfloat16
assert accum_dtype == T.float32
shape = [B, S, H, D]
@T.prim_func
......@@ -52,11 +52,11 @@ def postprocess(
kv_group=1,
block_N=64,
threads=128,
dtype="bfloat16",
accum_dtype="float",
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert dtype == T.bfloat16
assert accum_dtype == T.float32
dkv_shape = [B, S_kv, kv_group, D + D_tail]
@T.prim_func
......@@ -95,15 +95,15 @@ def bwd(
block_size=32,
num_stages=0,
threads=256,
indices_dtype="int32",
dtype="bfloat16",
accum_dtype="float",
indices_dtype=T.int32,
dtype=T.bfloat16,
accum_dtype=T.float32,
):
assert is_causal == True, "non-casual is not supported now"
assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
assert dtype == T.bfloat16
assert accum_dtype == T.float32
assert indices_dtype == T.int32
if sm_scale is None:
sm_scale = (D + D_tail) ** (-0.5)
......@@ -116,9 +116,9 @@ def bwd(
indices_shape = [B, S, kv_group, topk]
delta_shape = [B, S, H]
lse_shape = [B, S, H]
assert indices_dtype == "int32"
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == T.int32
assert dtype == T.bfloat16
assert accum_dtype == T.float32
H = H_kv
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
......
......@@ -44,9 +44,9 @@ def sparse_mla_fwd(
o_shape = [batch, seq_len, heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
lse_shape = [batch, seq_len, heads]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
......
......@@ -53,9 +53,9 @@ def sparse_mla_fwd(
o_shape = [batch, seq_len, heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
lse_shape = [batch, seq_len, heads]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
indices_dtype = T.int32
dtype = T.bfloat16
accum_dtype = T.float32
G = kv_group
H = head_kv
......
......@@ -8,24 +8,24 @@ pass_configs = {
def convert_to_uint16(x):
hval = T.Cast("float16", x)
bits_uint = T.reinterpret("uint16", hval)
hval = T.Cast(T.float16, x)
bits_uint = T.reinterpret(T.uint16, hval)
bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
return bits_uint >> 8
def convert_to_uint32(x):
bits_uint = T.reinterpret("uint32", x)
bits_uint = T.reinterpret(T.uint32, x)
bits_uint = T.if_then_else(
x < 0,
~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
bits_uint | T.Cast("uint32", (0x80000000)),
~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)),
bits_uint | T.Cast(T.uint32, (0x80000000)),
)
return bits_uint
@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32):
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
RADIX = 1 << 8
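The convert_to_uint16 transform above exists so that unsigned radix bins sort in the same order as the original floats; a plain-Python sketch of the idea (struct's half-precision format stands in for the float16 cast and reinterpret):

```python
import struct

def order_preserving_u16(x: float) -> int:
    # float16 bit pattern, mirroring T.reinterpret(T.uint16, T.Cast(T.float16, x)).
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    # Negatives: flip every bit; non-negatives: set the sign bit.
    # Unsigned integer comparison then matches floating-point comparison.
    return (~bits & 0xFFFF) if x < 0 else (bits | 0x8000)
```

The kernel then uses the top 8 bits of this value (the `>> 8` in convert_to_uint16) as the radix bin, so larger floats always land in larger bins.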
......@@ -42,20 +42,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
tx = T.get_thread_binding()
s_threshold_bin_id = T.alloc_shared([1], "int32")
s_histogram = T.alloc_shared([RADIX + 1], "int32")
s_num_input = T.alloc_shared([2], "int32")
s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")
l_threshold_bin_id = T.alloc_var("int32")
l_new_topk = T.alloc_var("int32")
l_num_input = T.alloc_var("int32")
l_bin_id32 = T.alloc_var("int32")
l_val = T.alloc_var("int32")
l_start_pos = T.alloc_var("int32")
l_start_idx = T.alloc_var("int32")
l_end_idx = T.alloc_var("int32")
l_out_pos = T.alloc_var("int32")
s_threshold_bin_id = T.alloc_shared([1], T.int32)
s_histogram = T.alloc_shared([RADIX + 1], T.int32)
s_num_input = T.alloc_shared([2], T.int32)
s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32)
l_threshold_bin_id = T.alloc_var(T.int32)
l_new_topk = T.alloc_var(T.int32)
l_num_input = T.alloc_var(T.int32)
l_bin_id32 = T.alloc_var(T.int32)
l_val = T.alloc_var(T.int32)
l_start_pos = T.alloc_var(T.int32)
l_start_idx = T.alloc_var(T.int32)
l_end_idx = T.alloc_var(T.int32)
l_out_pos = T.alloc_var(T.int32)
l_new_topk = topk
l_start_idx = starts[bx]
......@@ -99,7 +99,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
input_idx = s * BLOCK_SIZE + tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
bin_id = convert_to_uint16(input[bx, input_idx])
l_bin_id32 = T.Cast("int32", bin_id)
l_bin_id32 = T.Cast(T.int32, bin_id)
if l_bin_id32 > l_threshold_bin_id:
# need a pos = T.atomic_add(s_histogram[bin_id32+1], 1)
pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
......@@ -128,7 +128,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast(
"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
)
T.atomic_add(s_histogram[l_bin_id32], 1)
T.sync_threads()
......@@ -157,7 +157,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
T.sync_threads()
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast(
"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
)
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
......
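The per-round histogram/threshold selection that tl_topk_impl performs can be sketched on the host as a simplified single pass (illustrative names; assumes 16-bit order-preserving keys as produced by convert_to_uint16):

```python
def radix_topk_threshold_bin(keys, k, radix_bits=8):
    # Histogram the top `radix_bits` of each 16-bit key, then scan from the
    # largest bin downward until k elements are covered; the bin where the
    # running count crosses k holds the top-k threshold.
    RADIX = 1 << radix_bits
    hist = [0] * RADIX
    for key in keys:
        hist[key >> (16 - radix_bits)] += 1
    covered = 0
    for b in range((RADIX) - 1, -1, -1):
        covered += hist[b]
        if covered >= k:
            return b
    return 0
```

In the kernel this scan is done cooperatively with atomic_add into s_histogram, and elements falling in the threshold bin are refined in subsequent rounds over lower-order bits.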
......@@ -50,7 +50,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
source_format=T.uint32,
num_bits=4,
fast_dequant=True,
block_M=256,
......@@ -90,7 +90,7 @@ def matmul(
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
......@@ -121,7 +121,7 @@ def matmul(
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
......@@ -131,13 +131,13 @@ def matmul(
- Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
Notes and preconditions:
- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
- Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`.
- The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
- The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
- The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -193,7 +193,7 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
......@@ -204,7 +204,7 @@ def matmul(
- Writes the dequantized bfloat16 block into B_dequantize_shared.
Constraints:
- Supports only in_dtype="fp4" and out_dtype="bfloat16".
- Supports only in_dtype="fp4" and out_dtype=T.bfloat16.
- The helper assumes nbit == 4 and produces bfloat16 values.
- The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
......@@ -212,7 +212,7 @@ def matmul(
A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
......@@ -228,32 +228,32 @@ def matmul(
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
dtype (str): Target dtype string; must be T.bfloat16.
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8.
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
......@@ -364,7 +364,7 @@ def ref_program_twiddling(A, qB):
Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -384,7 +384,7 @@ def ref_program_simple(A, qB):
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -410,15 +410,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
......
......@@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
dtype (str): Destination dtype string (must be T.bfloat16).
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8.
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
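The FP4-to-BF16 widening above can be mirrored in plain Python to sanity-check the bit layout (a sketch of this file's mapping, not a general FP4 spec; struct reinterprets the bf16 pattern via the top 16 bits of a float32):

```python
import struct

def u8_fp4_to_bf16(val: int, pos: int, scale: int = 0) -> float:
    # Mirrors _tir_u8_to_f4_to_bf16: extract one 4-bit s/e/m field and widen it.
    f4 = (val >> (pos * 4)) & 0xF
    s = f4 >> 3                      # 1-bit sign
    e_f4 = (f4 & 0x6) >> 1           # 2-bit exponent
    # Exponent bias difference between bf16 and this fp4 layout is 126;
    # clamp the scaled exponent to 8 bits as the kernel does.
    e_bf16 = min(e_f4 + 126 + scale, 0xFF)
    m_f4 = f4 & 0x1                  # 1-bit mantissa
    bits = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)
    # bf16 occupies the top 16 bits of a float32 bit pattern.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]
```

For example, the low nibble 0b0010 decodes to 1.0 and the high nibble 0b1010 to -1.0, matching the sign/exponent assembly in the TIR helper.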
......@@ -90,7 +90,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
......@@ -116,7 +116,7 @@ def matmul(
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
......@@ -141,7 +141,7 @@ def matmul(
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
......@@ -170,7 +170,7 @@ def matmul(
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
......@@ -181,12 +181,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -262,19 +262,19 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
......@@ -394,7 +394,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
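The reference programs apply the per-group scale by indexing `Scale` with `column // 32`, i.e. every `scale_size` consecutive columns of a row share one exponent. A torch-free sketch of the same operation on nested lists:

```python
def apply_group_scale(B, scale, scale_size=32):
    # Multiply column j of row i by 2 ** scale[i][j // scale_size].
    return [[x * (2 ** scale[i][j // scale_size]) for j, x in enumerate(row)]
            for i, row in enumerate(B)]
```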
......@@ -417,7 +417,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
......@@ -441,7 +441,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
......@@ -469,7 +469,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
......@@ -498,16 +498,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
......@@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
dtype (str): Destination dtype string (must be T.bfloat16).
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8.
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
......@@ -90,7 +90,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
......@@ -116,7 +116,7 @@ def matmul(
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
......@@ -141,7 +141,7 @@ def matmul(
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
......@@ -170,7 +170,7 @@ def matmul(
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
......@@ -181,12 +181,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -262,19 +262,19 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype="bfloat16".
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
......@@ -402,7 +402,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
......@@ -427,7 +427,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
......@@ -453,7 +453,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
......@@ -483,7 +483,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = "bfloat16"
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
......@@ -514,16 +514,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
......@@ -26,7 +26,7 @@ def matmul(
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
A_shape = (M, K)
......@@ -149,21 +149,21 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
......@@ -182,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
block_K = 32 if in_dtype == T.float16 else 64
chunk = block_K // reduce_k
is_smooth_a = False
......@@ -365,7 +365,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_dtype = T.int8
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
......@@ -417,13 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3)
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3)
def main():
......@@ -9,22 +9,22 @@ import argparse
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "float16"
assert val.dtype == "uint8"
assert dtype == T.float16
assert val.dtype == T.uint8
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16")
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
e_f16 = e_f4 + tir.const(14, T.uint16)
m_f4 = f4 & tir.const(1, T.uint16)
m_f16 = m_f4
val_f16 = tir.reinterpret(
"float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")
T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16)
)
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16)
return val_f16
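The float16 variant uses the same technique as the bf16 one, but with the IEEE half layout (1 sign, 5 exponent, 10 mantissa bits), so the rebias is +14 and the field shifts differ. A sketch mirroring the assembly, checked via `struct`'s half-precision `'e'` format:

```python
import struct

def u8_to_f4_to_f16(val: int, pos: int) -> float:
    # Sketch of _tir_u8_to_f4_to_f16; like the kernel, e_f4 == 0 is not
    # special-cased to zero (see the commented-out tir.Select).
    f4 = (val >> (pos * 4)) & 0xF
    s = f4 >> 3
    e_f16 = ((f4 & 6) >> 1) + 14       # rebias 2-bit exponent into 5 bits
    m = f4 & 1
    bits = ((e_f16 | (s << 5)) << 10) | (m << 9)
    return struct.unpack('<e', struct.pack('<H', bits))[0]
```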
......@@ -60,7 +60,7 @@ def torch_convert(tensor):
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
......@@ -98,7 +98,7 @@ def test_fp4_fp16_convert_close():
K,
block_N,
block_K,
"float16",
T.float16,
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
......@@ -125,7 +125,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
......@@ -241,7 +241,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def ref_program(A, qB):
dtypeC = "float16"
dtypeC = T.float16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -252,7 +252,7 @@ def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
......@@ -265,7 +265,7 @@ def main(m=256, n=256, k=256, tune=False):
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
......@@ -9,15 +9,15 @@ import argparse
def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "int8"
assert val.dtype == "uint8"
assert dtype == T.int8
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, "uint8")
mask = tir.const((1 << nbit) - 1, T.uint8)
i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask
i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
i8 = i8_shifted >> tir.const(4, "int8")
i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8))
i8 = i8_shifted >> tir.const(4, T.int8)
return i8
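The shift-left-then-arithmetic-shift-right pattern in `_tir_u8_to_i4_to_i8` is the standard trick for sign-extending a 4-bit two's-complement code. Python's `>>` on signed ints is already arithmetic, so only the int8 reinterpret needs `struct` in this sketch:

```python
import struct

def u8_to_i4_to_i8(val: int, pos: int) -> int:
    # Mirror of _tir_u8_to_i4_to_i8: place the nibble in the top 4 bits of an
    # int8, then arithmetic-shift back so the sign bit propagates.
    i4 = (val >> (pos * 4)) & 0xF
    i8_shifted = struct.unpack('b', struct.pack('B', (i4 << 4) & 0xFF))[0]
    return i8_shifted >> 4
```

For example, the byte `0xF7` holds the codes 7 (low nibble) and -1 (high nibble).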
......@@ -35,7 +35,7 @@ def get_configs():
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
......@@ -85,7 +85,7 @@ def torch_convert(tensor):
def ref_program(A, qB):
dtypeC = "int32"
dtypeC = T.int32
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
......@@ -96,7 +96,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
......@@ -166,7 +166,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
)
profiler = kernel.get_profiler()
......@@ -177,7 +177,7 @@ def main(m=128, n=256, k=256, tune=False):
print(f"Tilelang: {latency} ms")
else:
best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
......@@ -17,7 +17,7 @@ def dequantize_gemv(
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
storage_dtype: str = "int8",
storage_dtype: T.dtype = T.int8,
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
......@@ -51,7 +51,7 @@ def dequantize_gemv(
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
import_source: Optional[str] = None
func_name: str = ""
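The `use_dp4a` flag above selects the CUDA `__dp4a` path when both the input and accumulator are integer typed: each instruction multiplies four packed signed int8 lanes and adds the result to an int32 accumulator. A sketch of those semantics:

```python
import struct

def dp4a(a: int, b: int, c: int) -> int:
    # Emulate CUDA __dp4a: a and b are 32-bit words holding four signed
    # int8 lanes each; return the lane-wise dot product plus accumulator c.
    xs = struct.unpack('4b', struct.pack('<I', a & 0xFFFFFFFF))
    ys = struct.unpack('4b', struct.pack('<I', b & 0xFFFFFFFF))
    return c + sum(x * y for x, y in zip(xs, ys))
```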
......@@ -159,11 +159,11 @@ def main() -> None:
M = 1
N = 1024
K = 1024
in_dtype = "float16"
out_dtype = "float16"
accum_dtype = "float16"
in_dtype = T.float16
out_dtype = T.float16
accum_dtype = T.float16
num_bits = 4
storage_dtype = "int8"
storage_dtype = T.int8
source_format = "uint"
n_partition = 4
reduce_thread = 32
......@@ -49,7 +49,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
......@@ -83,8 +83,8 @@ def matmul(
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
in_dtype (str): element type of A (e.g., "bfloat16").
out_dtype (str): output tensor element type (e.g., "bfloat16").
in_dtype (str): element type of A (e.g., T.bfloat16).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
......@@ -111,7 +111,7 @@ def matmul(
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
......@@ -137,7 +137,7 @@ def matmul(
import_source = import_source
# the dequant part is the same as in dequant_gemm
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
......@@ -147,12 +147,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
......@@ -227,9 +227,9 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
......@@ -259,8 +259,8 @@ def matmul(
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
sorted_token_ids: T.Tensor((padding_M), T.int32),
expert_ids: T.Tensor((padding_M // block_M), T.int32),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
......@@ -271,8 +271,8 @@ def matmul(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
expert_id = T.alloc_local((1), "int32") # the expert id for the current block
sorted_token_ids_shared = T.alloc_shared((block_M), T.int32)
expert_id = T.alloc_local((1), T.int32) # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
......@@ -346,7 +346,7 @@ def matmul(
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
dtypeC = "bfloat16"
dtypeC = T.bfloat16
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
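`sorted_token_ids` and `expert_ids` come from grouping the flattened (token, top-k slot) pairs by expert and padding each group to a multiple of `block_M`, so that every kernel block serves exactly one expert. A hypothetical sketch of that preprocessing (names illustrative; the actual preparation step lives outside this file):

```python
def align_tokens_to_blocks(flat_expert_ids, num_experts, block_m):
    # flat_expert_ids[i] is the expert chosen for flattened (token, k) slot i.
    pad = len(flat_expert_ids)                 # sentinel id for padded slots
    buckets = [[] for _ in range(num_experts)]
    for slot, e in enumerate(flat_expert_ids):
        buckets[e].append(slot)
    sorted_token_ids, expert_ids = [], []
    for e, slots in enumerate(buckets):
        if not slots:
            continue
        n_blocks = -(-len(slots) // block_m)   # ceil division
        slots = slots + [pad] * (n_blocks * block_m - len(slots))
        sorted_token_ids.extend(slots)
        expert_ids.extend([e] * n_blocks)      # one expert id per block
    return sorted_token_ids, expert_ids
```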
......@@ -451,9 +451,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
......@@ -467,9 +467,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi
topk,
E,
padding_M,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
......@@ -9,9 +9,9 @@ from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
pass_configs = {
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,