Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
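
To make the motivation concrete, here is a minimal plain-Python sketch (emulating fp16 storage with the `struct` module's `'e'` half-precision format; this is an illustration, not TileLang code) of why a float16 accumulator drifts where a float32 one does not:

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE half precision ('e' format) to emulate
    # storing an accumulator in float16.
    return struct.unpack('e', struct.pack('e', x))[0]

acc_fp16 = 0.0
acc_fp32 = 0.0
for _ in range(4096):
    acc_fp16 = to_fp16(acc_fp16 + 1.0)  # fp16 accumulator
    acc_fp32 += 1.0                     # wider accumulator

print(acc_fp16)  # 2048.0 -- fp16 spacing at 2048 is 2, so +1.0 rounds away
print(acc_fp32)  # 4096.0
```

Once the fp16 accumulator reaches 2048, adding 1.0 no longer changes it; an accumulation dtype of T.float32 sidesteps exactly this kind of loss inside long GEMM and attention reductions.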

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
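
As a rough illustration of the mapping described above (the alias names and table below are hypothetical; the real table lives in tilelang's dtypes module and may differ), dtype creation can normalize user-facing aliases to canonical display strings:

```python
# Hypothetical alias table for illustration only; the actual canonical
# names are defined in tilelang's dtypes module.
_CANONICAL = {
    "float": "float32",
    "half": "float16",
    "bf16": "bfloat16",
    "e4m3": "float8_e4m3fn",
}

def canonicalize(name: str) -> str:
    """Map a user-facing alias to its canonical display string."""
    return _CANONICAL.get(name, name)

print(canonicalize("float"))  # float32
print(canonicalize("int32"))  # int32 (already canonical, passed through)
```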

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
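
The e4m3 vs. e4m3fn distinction matters: the `fn` (finite) variant reserves no bit patterns for infinities and tops out at 448. A small decoder written from the published FP8 E4M3FN layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits) shows the encoding; this is an illustration, not code from the repository:

```python
import math

def decode_e4m3fn(byte: int) -> float:
    # float8_e4m3fn: 1 sign, 4 exponent (bias 7), 3 mantissa bits.
    # No infinities; only exponent=0b1111 with mantissa=0b111 is NaN.
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

print(decode_e4m3fn(0x7E))  # 448.0, the largest finite e4m3fn value
print(decode_e4m3fn(0x38))  # 1.0
```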

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -479,9 +479,9 @@ def tilelang_sparse_attention(
 q_shape = [batch, seq_len, heads, dim]
 kv_shape = [batch, seq_len, head_kv, dim]
 block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
-block_indices_dtype = "int32"
-dtype = "float16"
-accum_dtype = "float"
+block_indices_dtype = T.int32
+dtype = T.float16
+accum_dtype = T.float32
 block_S = block_size
 block_T = min(block_T, tilelang.math.next_power_of_2(dim))
@@ -876,7 +876,7 @@ if __name__ == "__main__":
 parser.add_argument("--dim", type=int, default=128, help="Head dimension")
 parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks")
 parser.add_argument("--block_size", type=int, default=32, help="Block size")
-parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)")
+parser.add_argument("--dtype", type=str, default=T.float16, help="Data type (float16 or float32)")
 parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
 parser.add_argument("--iterations", type=int, default=100, help="Number of iterations")
 parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
@@ -901,7 +901,7 @@ if __name__ == "__main__":
 if args.suite:
 run_benchmark_suite(impl=args.impl)
 else:
-dtype = torch.float16 if args.dtype == "float16" else torch.float32
+dtype = torch.float16 if args.dtype == T.float16 else torch.float32
 if args.impl in ["tilelang", "all"]:
 print("Benchmarking TileLang implementation:")
...
@@ -49,9 +49,9 @@ def tilelang_kernel_fwd(
 o_slc_shape = [batch, seq_len, heads, dim]
 lse_slc_shape = [batch, seq_len, heads]
 block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
-block_indices_dtype = "int32"
-dtype = "float16"
-accum_dtype = "float"
+block_indices_dtype = T.int32
+dtype = T.float16
+accum_dtype = T.float32
 block_S = block_size
 block_T = min(128, tilelang.math.next_power_of_2(dim))
@@ -170,8 +170,8 @@ def tilelang_kernel_bwd_dkv(
 block_size=64,
 groups=1,
 selected_blocks=16,
-dtype="float16",
-accum_dtype="float",
+dtype=T.float16,
+accum_dtype=T.float32,
 ):
 if scale is None:
 sm_scale = (1.0 / dim) ** 0.5
@@ -217,7 +217,7 @@ def tilelang_kernel_bwd_dkv(
 DO_slc: T.Tensor(do_slc_shape, dtype),
 DK: T.Tensor(dk_shape, dtype),
 DV: T.Tensor(dv_shape, dtype),
-BlockMask: T.Tensor(block_mask_shape, "int32"),
+BlockMask: T.Tensor(block_mask_shape, T.int32),
 ):
 with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
 K_shared = T.alloc_shared([BS, BK], dtype)
@@ -340,8 +340,8 @@ def tilelang_kernel_bwd_dqkv(
 block_size=64,
 groups=1,
 selected_blocks=16,
-dtype="float16",
-accum_dtype="float",
+dtype=T.float16,
+accum_dtype=T.float32,
 ):
 if scale is None:
 sm_scale = (1.0 / dim) ** 0.5
@@ -388,7 +388,7 @@ def tilelang_kernel_bwd_dqkv(
 DQ: T.Tensor(dq_shape, dtype),
 DK: T.Tensor(dk_shape, dtype),
 DV: T.Tensor(dv_shape, dtype),
-BlockMask: T.Tensor(block_mask_shape, "int32"),
+BlockMask: T.Tensor(block_mask_shape, T.int32),
 ):
 with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
 K_shared = T.alloc_shared([BS, BK], dtype)
@@ -505,8 +505,8 @@ def tilelang_kernel_preprocess(
 heads,
 seq_len,
 dim,
-dtype="float16",
-accum_dtype="float",
+dtype=T.float16,
+accum_dtype=T.float32,
 blk=32,
 ):
 from tilelang import language as T
@@ -548,7 +548,7 @@ def tilelang_kernel_block_mask(
 seq_len,
 selected_blocks,
 block_size,
-dtype="int32",
+dtype=T.int32,
 ):
 from tilelang import language as T
...
@@ -35,9 +35,9 @@ def native_sparse_attention(
 q_shape = [batch, 1, heads, dim]  # Changed seq_len to 1
 kv_shape = [batch, seq_len, head_kv, dim]
 block_indices_shape = [batch, 1, head_kv, selected_blocks]  # Changed seq_len to 1
-block_indices_dtype = "int32"
-dtype = "float16"
-accum_dtype = "float"
+block_indices_dtype = T.int32
+dtype = T.float16
+accum_dtype = T.float32
 block_S = block_size
 block_T = min(128, tilelang.math.next_power_of_2(dim))
...
@@ -26,9 +26,9 @@ def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, b
 q_shape = [batch, seq_len, heads, dim]
 kv_shape = [batch, seq_len, head_kv, dim]
 block_indices_shape = [batch, seq_len, head_kv, selected_blocks]
-block_indices_dtype = "int32"
-dtype = "float16"
-accum_dtype = "float"
+block_indices_dtype = T.int32
+dtype = T.float16
+accum_dtype = T.float32
 block_S = block_size
 block_T = min(128, tilelang.math.next_power_of_2(dim))
...
@@ -38,12 +38,12 @@ def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scal
 block_counts_shape = [c_seq_len, head_kv]
 offsets_shape = [batch + 1]
 token_indices_shape = [c_seq_len, 2]
-block_indices_dtype = "int32"
-block_counts_dtype = "int32"
-offsets_dtype = "int32"
-token_indices_dtype = "int32"
-dtype = "float16"
-accum_dtype = "float"
+block_indices_dtype = T.int32
+block_counts_dtype = T.int32
+offsets_dtype = T.int32
+token_indices_dtype = T.int32
+dtype = T.float16
+accum_dtype = T.float32
 block_S = block_size
 block_T = min(128, tilelang.math.next_power_of_2(dim))
...
@@ -97,9 +97,9 @@ def mqa_attn_return_logits(
 ):
 if block_Q is None:
 block_Q = 128 // heads
-dtype = "float8_e4m3"
-accum_dtype = "float"
-index_dtype = "int32"
+dtype = T.float8_e4m3fn
+accum_dtype = T.float32
+index_dtype = T.int32
 seq_len = T.dynamic("seq_len")
 seq_len_kv = T.dynamic("seq_len_kv")
@@ -178,8 +178,8 @@ def clean_logits_(
 seq_len = T.dynamic("seq_len")
 seq_len_kv = T.dynamic("seq_len_kv")
-dtype = "float"
-indices_dtype = "int32"
+dtype = T.float
+indices_dtype = T.int32
 @T.prim_func
 def clean_logits_kernel(
...
@@ -11,21 +11,21 @@ pass_configs = {
 tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
 }
-FP8 = "float8_e4m3"
-BF16 = "bfloat16"
-FP32 = "float32"
+FP8 = T.float8_e4m3fn
+BF16 = T.bfloat16
+FP32 = T.float32
 def fast_log2_ceil(x):
-bits_x = T.reinterpret("uint32", x)
+bits_x = T.reinterpret(T.uint32, x)
 exp_x = (bits_x >> 23) & 0xFF
 man_bits = bits_x & ((1 << 23) - 1)
-return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
+return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
 def fast_pow2(x):
 bits_x = (x + 127) << 23
-return T.reinterpret("float32", bits_x)
+return T.reinterpret(T.float32, bits_x)
 def fast_round_scale(amax, fp8_max_inv):
@@ -107,8 +107,8 @@ def act_quant(x: torch.Tensor,
 @tilelang.jit(pass_configs=pass_configs)
-def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
-assert out_dtype in [BF16, "float32"]
+def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=T.float32):
+assert out_dtype in [BF16, T.float32]
 M = T.dynamic("M")
 group_size = 128
...
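
The `fast_log2_ceil`/`fast_pow2` bit tricks in this hunk can be sanity-checked in plain Python by operating on the same IEEE float32 fields via `struct` (an emulation of the kernel logic for illustration, not TileLang code):

```python
import struct

def fast_log2_ceil(x: float) -> int:
    # ceil(log2(x)) for positive normal floats: read the biased exponent
    # and bump by one whenever any mantissa bit is set (x not a power of 2).
    bits = struct.unpack('I', struct.pack('f', x))[0]
    exp = (bits >> 23) & 0xFF
    man = bits & ((1 << 23) - 1)
    return exp - 127 + (1 if man != 0 else 0)

def fast_pow2(e: int) -> float:
    # Build 2**e by writing the exponent field directly (normal range only).
    return struct.unpack('f', struct.pack('I', (e + 127) << 23))[0]

print(fast_log2_ceil(8.0))  # 3 (exact power of two, mantissa is zero)
print(fast_log2_ceil(9.0))  # 4 (mantissa nonzero, so round up)
print(fast_pow2(-3))        # 0.125
```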
@@ -13,11 +13,11 @@ def preprocess(
 D,
 block_ND=32,
 num_stages=5,
-dtype="bfloat16",
-accum_dtype="float",
+dtype=T.bfloat16,
+accum_dtype=T.float32,
 ):
-assert dtype == "bfloat16"
-assert accum_dtype == "float"
+assert dtype == T.bfloat16
+assert accum_dtype == T.float32
 shape = [B, S, H, D]
 @T.prim_func
@@ -52,11 +52,11 @@ def postprocess(
 kv_group=1,
 block_N=64,
 threads=128,
-dtype="bfloat16",
-accum_dtype="float",
+dtype=T.bfloat16,
+accum_dtype=T.float32,
 ):
-assert dtype == "bfloat16"
-assert accum_dtype == "float"
+assert dtype == T.bfloat16
+assert accum_dtype == T.float32
 dkv_shape = [B, S_kv, kv_group, D + D_tail]
 @T.prim_func
@@ -95,15 +95,15 @@ def bwd(
 block_size=32,
 num_stages=0,
 threads=256,
-indices_dtype="int32",
-dtype="bfloat16",
-accum_dtype="float",
+indices_dtype=T.int32,
+dtype=T.bfloat16,
+accum_dtype=T.float32,
 ):
 assert is_causal == True, "non-casual is not supported now"
 assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
-assert dtype == "bfloat16"
-assert accum_dtype == "float"
-assert indices_dtype == "int32"
+assert dtype == T.bfloat16
+assert accum_dtype == T.float32
+assert indices_dtype == T.int32
 if sm_scale is None:
 sm_scale = (D + D_tail) ** (-0.5)
@@ -116,9 +116,9 @@ def bwd(
 indices_shape = [B, S, kv_group, topk]
 delta_shape = [B, S, H]
 lse_shape = [B, S, H]
-assert indices_dtype == "int32"
-assert dtype == "bfloat16"
-assert accum_dtype == "float"
+assert indices_dtype == T.int32
+assert dtype == T.bfloat16
+assert accum_dtype == T.float32
 H = H_kv
 padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
...
@@ -44,9 +44,9 @@ def sparse_mla_fwd(
 o_shape = [batch, seq_len, heads, dim]
 indices_shape = [batch, seq_len, kv_group, topk]
 lse_shape = [batch, seq_len, heads]
-indices_dtype = "int32"
-dtype = "bfloat16"
-accum_dtype = "float"
+indices_dtype = T.int32
+dtype = T.bfloat16
+accum_dtype = T.float32
 G = kv_group
 H = head_kv
...
@@ -53,9 +53,9 @@ def sparse_mla_fwd(
 o_shape = [batch, seq_len, heads, dim]
 indices_shape = [batch, seq_len, kv_group, topk]
 lse_shape = [batch, seq_len, heads]
-indices_dtype = "int32"
-dtype = "bfloat16"
-accum_dtype = "float"
+indices_dtype = T.int32
+dtype = T.bfloat16
+accum_dtype = T.float32
 G = kv_group
 H = head_kv
...
@@ -8,24 +8,24 @@ pass_configs = {
 def convert_to_uint16(x):
-hval = T.Cast("float16", x)
-bits_uint = T.reinterpret("uint16", hval)
+hval = T.Cast(T.float16, x)
+bits_uint = T.reinterpret(T.uint16, hval)
 bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000))
 return bits_uint >> 8
 def convert_to_uint32(x):
-bits_uint = T.reinterpret("uint32", x)
+bits_uint = T.reinterpret(T.uint32, x)
 bits_uint = T.if_then_else(
 x < 0,
-~bits_uint & T.Cast("uint32", (0xFFFFFFFF)),
-bits_uint | T.Cast("uint32", (0x80000000)),
+~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)),
+bits_uint | T.Cast(T.uint32, (0x80000000)),
 )
 return bits_uint
 @tilelang.jit(pass_configs=pass_configs)
-def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
+def tl_topk_impl(topk, in_dtype=T.float32, out_dtype=T.int32):
 batch = T.dynamic("batch")
 seq_len = T.dynamic("seq_len")
 RADIX = 1 << 8
@@ -42,20 +42,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
 with T.Kernel(batch, threads=BLOCK_SIZE) as (bx):
 tx = T.get_thread_binding()
-s_threshold_bin_id = T.alloc_shared([1], "int32")
-s_histogram = T.alloc_shared([RADIX + 1], "int32")
-s_num_input = T.alloc_shared([2], "int32")
-s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], "int32")
-l_threshold_bin_id = T.alloc_var("int32")
-l_new_topk = T.alloc_var("int32")
-l_num_input = T.alloc_var("int32")
-l_bin_id32 = T.alloc_var("int32")
-l_val = T.alloc_var("int32")
-l_start_pos = T.alloc_var("int32")
-l_start_idx = T.alloc_var("int32")
-l_end_idx = T.alloc_var("int32")
-l_out_pos = T.alloc_var("int32")
+s_threshold_bin_id = T.alloc_shared([1], T.int32)
+s_histogram = T.alloc_shared([RADIX + 1], T.int32)
+s_num_input = T.alloc_shared([2], T.int32)
+s_input_idx = T.alloc_shared([2, SMEM_INPUT_SIZE], T.int32)
+l_threshold_bin_id = T.alloc_var(T.int32)
+l_new_topk = T.alloc_var(T.int32)
+l_num_input = T.alloc_var(T.int32)
+l_bin_id32 = T.alloc_var(T.int32)
+l_val = T.alloc_var(T.int32)
+l_start_pos = T.alloc_var(T.int32)
+l_start_idx = T.alloc_var(T.int32)
+l_end_idx = T.alloc_var(T.int32)
+l_out_pos = T.alloc_var(T.int32)
 l_new_topk = topk
 l_start_idx = starts[bx]
@@ -99,7 +99,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
 input_idx = s * BLOCK_SIZE + tx
 if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
 bin_id = convert_to_uint16(input[bx, input_idx])
-l_bin_id32 = T.Cast("int32", bin_id)
+l_bin_id32 = T.Cast(T.int32, bin_id)
 if l_bin_id32 > l_threshold_bin_id:
 # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1)
 pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
@@ -128,7 +128,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
 for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
 if s * BLOCK_SIZE + tx < l_num_input:
 l_bin_id32 = T.Cast(
-"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
+T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
 )
 T.atomic_add(s_histogram[l_bin_id32], 1)
 T.sync_threads()
@@ -157,7 +157,7 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
 T.sync_threads()
 if s * BLOCK_SIZE + tx < l_num_input:
 l_bin_id32 = T.Cast(
-"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
+T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
 )
 if l_bin_id32 > l_threshold_bin_id:
 pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
...
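
The `convert_to_uint32` helper above applies the standard order-preserving float-to-unsigned key transform used by radix select: flip all bits of negative floats, set the sign bit of non-negative ones. A plain-Python emulation (for illustration, not kernel code) makes the monotonicity easy to verify:

```python
import struct

def convert_to_uint32(x: float) -> int:
    # Map a float32 to a uint32 so that unsigned comparison of the keys
    # matches numeric comparison of the floats.
    bits = struct.unpack('I', struct.pack('f', x))[0]
    if x < 0:
        return ~bits & 0xFFFFFFFF  # negatives: flip everything
    return bits | 0x80000000       # non-negatives: set the sign bit

vals = [-3.5, -0.25, 0.0, 0.5, 2.0]
keys = [convert_to_uint32(v) for v in vals]
assert keys == sorted(keys)  # ascending floats give ascending keys
```

Because the transform is monotone, the radix passes over the key bytes (8 bits per round, RADIX = 256) select the same elements a float comparison would.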
@@ -50,7 +50,7 @@ def matmul(
 in_dtype,
 out_dtype,
 accum_dtype,
-source_format="uint",
+source_format=T.uint32,
 num_bits=4,
 fast_dequant=True,
 block_M=256,
@@ -90,7 +90,7 @@ def matmul(
 A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
 """
 num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
 QK = K // num_elems_per_byte
 Block_QK = block_K // num_elems_per_byte
@@ -121,7 +121,7 @@ def matmul(
 assert func_name is not None, "mxfp_intrin_info is not found"
 import_source = import_source
-def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
+def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
 """
 Create a TileLang macro that performs fast, twiddling-based dequantization from packed FP4 to BF16 using an external runtime plugin.
@@ -131,13 +131,13 @@ def matmul(
 - Writes the dequantized BF16 values back to a shared dequantized buffer for use by the kernel.
 Notes and preconditions:
-- Asserts that `in_dtype == "fp4"` and `out_dtype == "bfloat16"`.
+- Asserts that `in_dtype == "fp4"` and `out_dtype == T.bfloat16`.
 - The generated macro depends on several surrounding-scope symbols (e.g., `import_source`, `func_name`, `block_K`, `Block_QK`, `threads`, `num_elems_per_byte`, `storage_dtype`, and `out_dtype`) and expects them to be defined consistently in the enclosing kernel.
 - The macro is optimized for block-wise, per-thread transactions sized to the target storage width (uses a MAX_TRANSACTION_SIZE_BITS constant) and uses local/register buffers sized accordingly.
 - The macro uses `T.import_source` to bring the external plugin into the module and `T.call_extern` to perform the high-throughput dequantization; callers must ensure the external function matches the expected calling convention and memory layout.
 """
 assert in_dtype in ["fp4"]
-assert out_dtype in ["bfloat16"]
+assert out_dtype in [T.bfloat16]
 # Some variables for dequantization in each thread
 MAX_TRANSACTION_SIZE_BITS = 128
@@ -193,7 +193,7 @@ def matmul(
 return fast_dequant_bf16_fp4_twiddling
-def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
+def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
 """
 Create a simple TIR dequantization macro that converts packed 4-bit FP (FP4) stored in uint8 into bfloat16.
@@ -204,7 +204,7 @@ def matmul(
 - Writes the dequantized bfloat16 block into B_dequantize_shared.
 Constraints:
-- Supports only in_dtype="fp4" and out_dtype="bfloat16".
+- Supports only in_dtype="fp4" and out_dtype=T.bfloat16.
 - The helper assumes nbit == 4 and produces bfloat16 values.
 - The macro uses a fixed test-scale of 0 (no per-element scaling) as written.
@@ -212,7 +212,7 @@ def matmul(
 A TIR macro function performing the described in-place block dequantization from packed uint8 FP4 to bfloat16.
 """
 assert in_dtype in ["fp4"]
-assert out_dtype in ["bfloat16"]
+assert out_dtype in [T.bfloat16]
 def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
 """
@@ -228,32 +228,32 @@ def matmul(
 val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
 pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16". dtype (str): Target dtype string; must be T.bfloat16.
Returns: Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes: Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - The function asserts `nbit == 4`, `dtype == T.bfloat16`, and that `val.dtype` is T.uint8.
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits. bit fields and clamps the computed exponent to fit into 8 bits.
""" """
assert nbit == 4 assert nbit == 4
assert dtype == "bfloat16" assert dtype == T.bfloat16
assert val.dtype == "uint8" assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, "uint16") mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, "uint16") s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16") e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126 # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16") e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8 # Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the min function to limit the exponential part to 8 bits
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, T.uint16))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
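The nibble-to-bfloat16 bit manipulation above can be modeled in plain Python. The sketch below is a hypothetical reference helper (not part of the kernel): it mirrors the same shift/mask/rebias steps on ordinary ints, then decodes the assembled bfloat16 bit pattern back to a float for inspection. It assumes a non-negative `scale`, so the exponent field stays in the normal (non-subnormal) range.

```python
def u8_to_f4_to_bf16_value(val: int, pos: int, scale: int = 0) -> float:
    """Pure-Python model of the _tir_u8_to_f4_to_bf16 bit manipulation.

    val is a packed uint8 holding two FP4 (e2m1) nibbles; pos selects which one.
    Assumes scale >= 0 so the rebias keeps the exponent in the normal range.
    """
    assert 0 <= val <= 0xFF and pos in (0, 1)
    f4 = (val >> (pos * 4)) & 0xF                    # extract the 4-bit field
    s = f4 >> 3                                      # sign bit
    e_f4 = (f4 & 0b0110) >> 1                        # 2-bit exponent
    # Exponent bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126;
    # clamp the rebased exponent to 8 bits, as the kernel does with T.min.
    e_bf16 = min(e_f4 + 126 + scale, (1 << 8) - 1)
    m_f4 = f4 & 1                                    # 1-bit mantissa
    bits = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)  # bf16 bit pattern
    # Decode the bf16 pattern (sign / 8-bit exponent / 7-bit mantissa).
    sign = -1.0 if (bits >> 15) else 1.0
    exp = (bits >> 7) & 0xFF
    mant = bits & 0x7F
    return sign * (2.0 ** (exp - 127)) * (1.0 + mant / 128.0)
```

For example, the packed byte `0x71` carries nibble `0x1` at pos 0 (decodes to 0.75 under this mapping) and `0x7` at pos 1 (decodes to 6.0, the e2m1 maximum).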
@@ -364,7 +364,7 @@ def ref_program_twiddling(A, qB):
Returns:
torch.Tensor: Result matrix C with shape (M, N) in bfloat16.
"""
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
@@ -384,7 +384,7 @@ def ref_program_simple(A, qB):
Returns:
torch.Tensor: Resulting matrix C in bfloat16 with shape (M, N).
"""
dtypeC = T.bfloat16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
@@ -410,15 +410,15 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
...
@@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be T.bfloat16).
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8.
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
@@ -90,7 +90,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
@@ -116,7 +116,7 @@ def matmul(
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
@@ -141,7 +141,7 @@ def matmul(
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
@@ -170,7 +170,7 @@ def matmul(
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
@@ -181,12 +181,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
@@ -262,19 +262,19 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
@@ -394,7 +394,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
@@ -417,7 +417,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
@@ -441,7 +441,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
@@ -469,7 +469,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = T.bfloat16
B = torch_convert(qB)
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
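Each reference program applies the per-group scale with `Scale[:, torch.arange(K) // 32]`, i.e. one exponent shared by every 32 consecutive columns (matching `scale_size=32`). A pure-Python sketch of that indexing for a single row; `apply_group_scale` is a made-up name used only for illustration:

```python
def apply_group_scale(row, scale_exponents, scale_size=32):
    """Scale each element by 2**e, where e is the exponent of its 32-wide group.

    Mirrors B *= 2 ** Scale[:, arange(K) // scale_size] for one row of B.
    """
    assert len(row) % scale_size == 0
    assert len(scale_exponents) == len(row) // scale_size
    return [v * 2.0 ** scale_exponents[j // scale_size] for j, v in enumerate(row)]
```

With a 64-element row and exponents `[0, 3]`, the first 32 elements are unchanged and the last 32 are multiplied by 8.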
@@ -498,16 +498,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
...
@@ -20,31 +20,31 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be T.bfloat16).
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == T.bfloat16, and val.dtype == T.uint8.
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == T.bfloat16
assert val.dtype == T.uint8
mask = tir.const((1 << nbit) - 1, T.uint16)
f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
s = f4 >> tir.const(3, T.uint16)
e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, T.uint16)
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, T.uint16)
val_bf16 = tir.reinterpret(
T.bfloat16,
((((s << tir.const(8, T.uint16)) | e_bf16) << tir.const(7, T.uint16)) | (m_f4 << tir.const(6, T.uint16))).astype(T.uint16),
)
return val_bf16
@@ -90,7 +90,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
@@ -116,7 +116,7 @@ def matmul(
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
@@ -141,7 +141,7 @@ def matmul(
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
@@ -170,7 +170,7 @@ def matmul(
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source
def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
@@ -181,12 +181,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
@@ -262,19 +262,19 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Create a simple (scalar) dequantization macro that converts 4-bit packed inputs to bfloat16.
Returns a T.macro that, given shared-storage buffers B_shared, B_dequantize_shared, a Scale tensor, and block index k, unpacks 4-bit values from B_shared, converts each nibble to a bfloat16 value using _tir_u8_to_f4_to_bf16, applies the per-element exponential Scale, and writes the dequantized BF16 block into B_dequantize_shared.
Notes:
- Only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro expects B_shared and B_dequantize_shared to have the shapes established in the enclosing scope (B_shared_shape, B_dequantize_shared_shape) and performs block-local copying into allocated fragments before elementwise conversion.
- Scale holds the exponent-like scaling values indexed per output element as used by the conversion helper.
"""
assert in_dtype in ["fp4"]
assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
@@ -402,7 +402,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
@@ -427,7 +427,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
Returns:
torch.Tensor: Resulting matrix C with shape (M, N) in bfloat16.
"""
dtypeC = T.bfloat16
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
@@ -453,7 +453,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
@@ -483,7 +483,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
No in-place modification is performed on inputs (a local floating copy of B is scaled).
"""
dtypeC = T.bfloat16
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
@@ -514,16 +514,16 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m, n, k, T.bfloat16, T.bfloat16, T.float32, num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
n,
k,
T.bfloat16,
T.bfloat16,
T.float32,
num_bits=4,
scale_size=scale_size,
block_M=256,
...
@@ -26,7 +26,7 @@ def matmul(
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = T.int8
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
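The two comprehensions above split a dtype name into its letter prefix and bit width. The sketch below replicates that parsing; note it assumes `storage_dtype` still behaves as the plain string "int8" (the PR's `T.int8` must stringify compatibly for this digit/letter split to keep working).

```python
def split_storage_dtype(storage_dtype: str):
    """Replicate the digit/letter split applied to storage_dtype above."""
    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))   # e.g. 8
    storage_type = "".join(c for c in storage_dtype if not c.isdigit())    # e.g. "int"
    return storage_type, storage_nbit
```

For example, "int8" splits into ("int", 8) and "uint16" into ("uint", 16).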
A_shape = (M, K) A_shape = (M, K)
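The `storage_nbit` / `storage_type` comprehensions in the hunk above recover the bit width and base type from a dtype string such as `"int8"`. A standalone sketch of that parsing (plain Python; the helper name is hypothetical, not part of the patch):

```python
def split_storage_dtype(storage_dtype: str):
    """Split a dtype string like "int8" into ("int", 8), mirroring the
    digit/non-digit comprehensions used in the kernel above."""
    storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
    storage_type = "".join(c for c in storage_dtype if not c.isdigit())
    return storage_type, storage_nbit

print(split_storage_dtype("int8"))    # ('int', 8)
print(split_storage_dtype("uint32"))  # ('uint', 32)
```

Note that this only works while `storage_dtype` iterates like a string; if it becomes a non-string dtype object, the comprehensions would need `str(storage_dtype)` first.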
@@ -149,21 +149,21 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
-"float16",
-"int8",
+T.float16,
+T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
-"float16",
-"float32",
-"int32",
+T.float16,
+T.float32,
+T.int32,
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
-storage_dtype = "int8"
+storage_dtype = T.int8
micro_size_x = micro_size_y = micro_size_k = 16
-if out_dtype == "int32":
+if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
@@ -182,7 +182,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
-block_K = 32 if in_dtype == "float16" else 64
+block_K = 32 if in_dtype == T.float16 else 64
chunk = block_K // reduce_k
is_smooth_a = False
@@ -365,7 +365,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
-storage_dtype = "int8"
+storage_dtype = T.int8
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
@@ -417,13 +417,13 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
-run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
-run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
+run_gemm(256, 256, 256, T.float16, T.float16, T.float16, 128, 128, 32, num_threads=128)
+run_gemm(256, 256, 256, T.int8, T.int32, T.int32, 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
-assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3)
+assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, T.float16, T.float16, T.float16, 3)
def main():
...
@@ -9,22 +9,22 @@ import argparse
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
-assert dtype == "float16"
-assert val.dtype == "uint8"
+assert dtype == T.float16
+assert val.dtype == T.uint8
# e_f4 == 0 -> e_f16 = 0
# e_f4 != 0 -> e_f16 = e_f4 + ExponentialBias(f16, f4) = e_f4 + (2^4 - 2^1) = e_f4 + 14
# s1e2m1
-mask = tir.const((1 << nbit) - 1, "uint16")
-f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
-s = f4 >> tir.const(3, "uint16")
-e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
-e_f16 = e_f4 + tir.const(14, "uint16")
-m_f4 = f4 & tir.const(1, "uint16")
+mask = tir.const((1 << nbit) - 1, T.uint16)
+f4 = (val >> (pos.astype(T.uint16) * tir.const(nbit, T.uint16))) & mask
+s = f4 >> tir.const(3, T.uint16)
+e_f4 = (f4 & tir.const(6, T.uint16)) >> tir.const(1, T.uint16)
+e_f16 = e_f4 + tir.const(14, T.uint16)
+m_f4 = f4 & tir.const(1, T.uint16)
m_f16 = m_f4
val_f16 = tir.reinterpret(
-"float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")
+T.float16, ((e_f16 | (s << tir.const(5, T.uint16))) << tir.const(10, T.uint16) | m_f16 << tir.const(9, T.uint16)).astype(T.uint16)
)
-# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
+# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, T.float16), val_f16)
return val_f16
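The bit manipulation above re-biases the s1e2m1 exponent (+14) and repositions sign, exponent, and mantissa into an IEEE half bit pattern. A pure-Python sketch of the same decode, using the stdlib `struct` half-precision format (illustrative only, not part of the patch; like the TIR version with its commented-out `tir.Select`, the `e_f4 == 0` zero case is not special-cased):

```python
import struct

def u8_to_f4_to_f16(val: int, pos: int) -> float:
    """Extract the pos-th 4-bit s1e2m1 value from byte `val` and widen it
    to float16, mirroring the uint16 bit ops in _tir_u8_to_f4_to_f16."""
    nbit = 4
    mask = (1 << nbit) - 1
    f4 = (val >> (pos * nbit)) & mask
    s = f4 >> 3            # sign bit
    e_f4 = (f4 & 6) >> 1   # 2-bit exponent
    e_f16 = e_f4 + 14      # re-bias: bias(f16)=15, bias(f4)=1
    m_f4 = f4 & 1          # 1-bit mantissa
    bits = ((e_f16 | (s << 5)) << 10) | (m_f4 << 9)
    # reinterpret the uint16 bit pattern as an IEEE binary16
    return struct.unpack("<e", struct.pack("<H", bits))[0]

# nibble 0b0011: s=0, e=1, m=1 -> 2^(1-1) * (1 + 0.5) = 1.5
print(u8_to_f4_to_f16(0b0011, 0))  # 1.5
print(u8_to_f4_to_f16(0b1011, 0))  # -1.5 (sign bit set)
```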
@@ -60,7 +60,7 @@ def torch_convert(tensor):
@tilelang.jit(out_idx=[1])
def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@@ -98,7 +98,7 @@ def test_fp4_fp16_convert_close():
K,
block_N,
block_K,
-"float16",
+T.float16,
)
B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
@@ -125,7 +125,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
@@ -241,7 +241,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
def ref_program(A, qB):
-dtypeC = "float16"
+dtypeC = T.float16
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
@@ -252,7 +252,7 @@ def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
-kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
+kernel = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
@@ -265,7 +265,7 @@ def main(m=256, n=256, k=256, tune=False):
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
-best_result = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)
+best_result = matmul(m, n, k, T.float16, T.float16, T.float32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
...
@@ -9,15 +9,15 @@ import argparse
def _tir_u8_to_i4_to_i8(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert nbit == 4
-assert dtype == "int8"
-assert val.dtype == "uint8"
-mask = tir.const((1 << nbit) - 1, "uint8")
-i4 = (val >> (pos.astype("uint8") * tir.const(nbit, "uint8"))) & mask
-i8_shifted = tir.reinterpret("int8", i4 << tir.const(4, "uint8"))
-i8 = i8_shifted >> tir.const(4, "int8")
+assert dtype == T.int8
+assert val.dtype == T.uint8
+mask = tir.const((1 << nbit) - 1, T.uint8)
+i4 = (val >> (pos.astype(T.uint8) * tir.const(nbit, T.uint8))) & mask
+i8_shifted = tir.reinterpret(T.int8, i4 << tir.const(4, T.uint8))
+i8 = i8_shifted >> tir.const(4, T.int8)
return i8
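The shift-left / reinterpret / arithmetic-shift-right sequence above is the standard sign-extension trick for a 4-bit two's-complement field. A pure-Python mirror (illustrative sketch, not part of the patch):

```python
def u8_to_i4_to_i8(val: int, pos: int) -> int:
    """Extract the pos-th 4-bit field from byte `val` and sign-extend it
    to int8, mirroring _tir_u8_to_i4_to_i8."""
    nbit = 4
    mask = (1 << nbit) - 1
    i4 = (val >> (pos * nbit)) & mask
    i8_shifted = (i4 << 4) & 0xFF   # place the nibble in the high bits
    if i8_shifted >= 128:           # reinterpret the uint8 pattern as int8
        i8_shifted -= 256
    return i8_shifted >> 4          # arithmetic shift restores the sign

print(u8_to_i4_to_i8(0b1111, 0))  # -1
print(u8_to_i4_to_i8(0b0111, 0))  # 7
```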
@@ -35,7 +35,7 @@ def get_configs():
@tilelang.jit(out_idx=[1])
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
B_shape = (N, K // num_elems_per_byte)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@@ -85,7 +85,7 @@ def torch_convert(tensor):
def ref_program(A, qB):
-dtypeC = "int32"
+dtypeC = T.int32
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
@@ -96,7 +96,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
@@ -166,7 +166,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if not tune:
-kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
+kernel = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
)
profiler = kernel.get_profiler()
@@ -177,7 +177,7 @@ def main(m=128, n=256, k=256, tune=False):
print(f"Tilelang: {latency} ms")
else:
-best_result = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)
+best_result = matmul_int8xint4(m, n, k, T.int8, T.int32, T.int32, num_bits=4, tune=tune)
best_latency = best_result.latency
best_config = best_result.config
print(f"Best latency: {best_latency}")
...
@@ -17,7 +17,7 @@ def dequantize_gemv(
out_dtype: str,
accum_dtype: str,
num_bits: int = 4,
-storage_dtype: str = "int8",
+storage_dtype: T.dtype = T.int8,
source_format: str = "uint",
n_partition: int = 4,
reduce_thread: int = 32,
@@ -51,7 +51,7 @@ def dequantize_gemv(
C_shape = (M, N)
dp4a_size = 4
-use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
+use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
import_source: Optional[str] = None
func_name: str = ""
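The hunk above gates the int8 fast path on `use_dp4a`: the dp4a instruction multiplies four packed int8 lane pairs and accumulates the sum into an int32 in a single step, which is why `dp4a_size = 4`. A scalar pure-Python sketch of the semantics the kernel relies on (illustrative only, not part of the patch):

```python
def dp4a(a4, b4, c):
    """Simulate the dp4a operation: a 4-way dot product of int8 lanes
    (each in [-128, 127]) accumulated into an int32 value c."""
    assert len(a4) == len(b4) == 4
    return c + sum(x * y for x, y in zip(a4, b4))

# 1*5 + 2*6 + 3*7 + 4*8 = 70, plus the running accumulator 10
print(dp4a([1, 2, 3, 4], [5, 6, 7, 8], 10))  # 80
```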
@@ -159,11 +159,11 @@ def main() -> None:
M = 1
N = 1024
K = 1024
-in_dtype = "float16"
-out_dtype = "float16"
-accum_dtype = "float16"
+in_dtype = T.float16
+out_dtype = T.float16
+accum_dtype = T.float16
num_bits = 4
-storage_dtype = "int8"
+storage_dtype = T.int8
source_format = "uint"
n_partition = 4
reduce_thread = 32
...
@@ -49,7 +49,7 @@ def matmul(
in_dtype,
out_dtype,
accum_dtype,
-source_format="uint",
+source_format=T.uint32,
num_bits=4,
scale_size=32,
fast_dequant=True,
@@ -83,8 +83,8 @@ def matmul(
topk (int): number of experts selected per token.
E (int): number of experts.
padding_M (int): padded number of tokens after grouping and block alignment.
-in_dtype (str): element type of A (e.g., "bfloat16").
-out_dtype (str): output tensor element type (e.g., "bfloat16").
+in_dtype (str): element type of A (e.g., T.bfloat16).
+out_dtype (str): output tensor element type (e.g., T.bfloat16).
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
@@ -111,7 +111,7 @@ def matmul(
"""
num_elems_per_byte = 8 // num_bits
-storage_dtype = "uint8"
+storage_dtype = T.uint8
QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
@@ -137,7 +137,7 @@ def matmul(
import_source = import_source
# the dequant part is the same as in dequant_gemm
-def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
+def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype=T.bfloat16):
"""
Return a TileLang macro that performs fast dequantization of twiddled FP4-packed data into BF16.
The returned macro has signature (B_shared, B_dequantize_shared, Scale, k) and:
@@ -147,12 +147,12 @@ def matmul(
- Writes the scaled BF16 results into B_dequantize_shared.
Notes:
-- This factory only supports in_dtype="fp4" and out_dtype="bfloat16".
+- This factory only supports in_dtype="fp4" and out_dtype=T.bfloat16.
- The macro depends on several names from the enclosing scope (e.g., import_source, func_name, DataType, num_elems_per_byte, storage_dtype, block_N, block_K, threads, scale_size); those must be defined and consistent with the kernel that will use the macro.
- The macro issues a T.import_source and T.call_extern to invoke the external intrinsic; ensure the external implementation matching `func_name` is available at compilation/runtime.
"""
assert in_dtype in ["fp4"]
-assert out_dtype in ["bfloat16"]
+assert out_dtype in [T.bfloat16]
# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
@@ -227,9 +227,9 @@ def matmul(
return fast_dequant_bf16_fp4_twiddling
-def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
+def get_simple_dequant_func(in_dtype="fp4", out_dtype=T.bfloat16):
assert in_dtype in ["fp4"]
-assert out_dtype in ["bfloat16"]
+assert out_dtype in [T.bfloat16]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
@@ -259,8 +259,8 @@ def matmul(
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
-sorted_token_ids: T.Tensor((padding_M), "int32"),
-expert_ids: T.Tensor((padding_M // block_M), "int32"),
+sorted_token_ids: T.Tensor((padding_M), T.int32),
+expert_ids: T.Tensor((padding_M // block_M), T.int32),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
@@ -271,8 +271,8 @@ def matmul(
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
-sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
-expert_id = T.alloc_local((1), "int32")  # the expert id for the current block
+sorted_token_ids_shared = T.alloc_shared((block_M), T.int32)
+expert_id = T.alloc_local((1), T.int32)  # the expert id for the current block
# To use 1D TMA, the last dim of Scale_shared must have stride=1
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
@@ -346,7 +346,7 @@ def matmul(
def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=256):
-dtypeC = "bfloat16"
+dtypeC = T.bfloat16
M, K = A.shape
E, N, QK = qB.shape
topk = topk_weights.shape[0] // M
@@ -451,9 +451,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi
topk,
E,
padding_M,
-"bfloat16",
-"bfloat16",
-"float32",
+T.bfloat16,
+T.bfloat16,
+T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
@@ -467,9 +467,9 @@ def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, wi
topk,
E,
padding_M,
-"bfloat16",
-"bfloat16",
-"float32",
+T.bfloat16,
+T.bfloat16,
+T.float32,
num_bits=num_bits,
scale_size=scale_size,
fast_dequant=fast_dequant,
...
@@ -9,9 +9,9 @@ from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
-BF16 = "bfloat16"
-FP32 = "float32"
-INT32 = "int32"
+BF16 = T.bfloat16
+FP32 = T.float32
+INT32 = T.int32
pass_configs = {
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
...