Unverified commit a7730272 authored by Lei Wang, committed by GitHub

[Language] Recommend using `T.dynamic` instead of `T.symbolic` (#1076)

* recommend using T.dynamic instead of T.symbolic

* lint fix

* lint fix
parent fd6cec58
......@@ -178,7 +178,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return matmul_relu_kernel
M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
......
......@@ -89,7 +89,7 @@ def elementwise_add(
The compilation above used a fixed shape, but in practice we often want a kernel that supports dynamic shapes. To compile such a kernel in TileLang, replace the concrete size with a dynamic symbolic value, which makes that dimension dynamic. The following example illustrates this:
```python
program = elementwise_add(T.symbolic("N"), threads=256, dtype="bfloat16")
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
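Once compiled against a symbolic `N`, the same kernel object can be launched with inputs of any length. A minimal usage sketch, assuming `elementwise_add` adds two 1-D bfloat16 tensors and `out_idx=-1` lets the compiled kernel allocate and return its output:

```python
import torch

# The N dimension is bound at call time from the input shapes, so one
# compiled kernel serves every size below.
for n in (1024, 4096, 10000):
    a = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(n, dtype=torch.bfloat16, device="cuda")
    c = kernel(a, b)
    torch.testing.assert_close(c, a + b, rtol=1e-2, atol=1e-2)
```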
......
......@@ -223,12 +223,12 @@ class SparseFlashAttn(torch.nn.Module):
block_N=block_N,
block_H=self.block_H,
page_block_size=page_block_size,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
num_pages=num_pages,
max_num_blocks_per_seq=T.symbolic("max_num_blocks_per_seq"),
max_selected_blocks=T.symbolic("max_selected_blocks"),
max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
max_selected_blocks=T.dynamic("max_selected_blocks"),
)
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
......
......@@ -206,11 +206,11 @@ class SparseFlashAttn(torch.nn.Module):
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -301,11 +301,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
max_selected_blocks=T.symbolic("max_selected_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
max_selected_blocks=T.dynamic("max_selected_blocks"))
output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
return output
......
......@@ -193,11 +193,11 @@ class SparseFlashAttn(torch.nn.Module):
self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=self.block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
props = torch.cuda.get_device_properties(torch.device("cuda:0"))
self.num_sm = props.multi_processor_count
......@@ -282,11 +282,11 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
block_N=block_size,
block_H=block_H,
num_split=T.symbolic("num_split"),
num_split=T.dynamic("num_split"),
num_stages=2,
threads=128,
max_cache_seqlen=T.symbolic("max_cache_seqlen"),
num_blocks=T.symbolic("num_blocks"))
max_cache_seqlen=T.dynamic("max_cache_seqlen"),
num_blocks=T.dynamic("num_blocks"))
glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
Output_partial = torch.empty((batch, heads, num_split, dim_v),
dtype=torch.float32,
......
......@@ -103,8 +103,8 @@ def mqa_attn_return_logits(
accum_dtype = "float"
index_dtype = "int32"
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
index_q_shape = [seq_len * heads, index_dim]
index_k_shape = [seq_len_kv, index_dim]
......@@ -182,8 +182,8 @@ def clean_logits_(
threads: int = 512,
block_K: int = 4096,
):
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
dtype = "float"
indices_dtype = "int32"
......
......@@ -34,7 +34,7 @@ def fast_round_scale(amax, fp8_max_inv):
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
M = T.symbolic("M")
M = T.dynamic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
......@@ -110,7 +110,7 @@ def act_quant(x: torch.Tensor,
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]
M = T.symbolic("M")
M = T.dynamic("M")
group_size = 128
block_M = 32
block_N = 128
......@@ -192,9 +192,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
b = T.dynamic("b")
m = T.dynamic("m")
n = T.dynamic("n")
blk_n1 = 512
blk_n2 = 128
......
......@@ -37,9 +37,9 @@ def sparse_mla_fwd(
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
head_kv = heads // kv_group
q_shape = [batch, seq_len, heads, dim + tail_dim]
......
......@@ -26,8 +26,8 @@ def convert_to_uint32(x):
@tilelang.jit(pass_configs=pass_configs)
def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
RADIX = 1 << 8
BLOCK_SIZE = 1024
SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K
......
......@@ -41,7 +41,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
M = 128 # M = T.symbolic("m") if you want to use dynamic shape
M = 128 # M = T.dynamic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
......
......@@ -48,7 +48,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return matmul_relu_kernel
M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
......
......@@ -24,7 +24,7 @@ def test_empty_kernel_lowering():
@tilelang.jit
def _empty_with_dead_code_kernel():
num_tokens = T.symbolic("num_tokens")
num_tokens = T.dynamic("num_tokens")
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
......
......@@ -395,14 +395,14 @@ def run_ctypes_dynamic_shape(M,
def test_ctypes_dynamic_shape():
run_ctypes_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_ctypes_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
......
......@@ -404,14 +404,14 @@ def run_cython_dynamic_shape(M,
def test_cython_dynamic_shape():
run_cython_dynamic_shape(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_cython_dynamic_shape(
T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
......@@ -473,7 +473,7 @@ def run_cython_dynamic_shape_with_out_idx(M,
def test_cython_dynamic_shape_with_out_idx():
run_cython_dynamic_shape_with_out_idx(
T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def matmul_int_variable(
......
......@@ -83,7 +83,7 @@ def run_tilelang_copy_with_stride(M=1024,
def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128)
def tilelang_copy_bufferload(num_tokens, dtype="float16"):
......
......@@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens = T.symbolic('num_tokens')
num_tokens = T.dynamic('num_tokens')
num_threads = 128
@T.prim_func
......
"""The language interface for tl programs."""
from typing import Optional, Callable, Dict
from typing import Optional
# from .parser import *
# now fully compatible with the upstream
# TIR script
......@@ -84,124 +84,10 @@ from .builtin import * # noqa: F401
from .utils import index_to_coordinates # noqa: F401
def symbolic(name: str, dtype: str = "int32"):
"""
Create a TIR symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# If order is row, use rasterization2DRow, otherwise use rasterization2DColumn
# The panel size is the number of threads in a warp
# Use to improve the L2 Cache Locality
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
return attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None
def annotate_layout(layout_map: Dict):
"""Annotate the layout of the buffer
Args:
layout_map (Dict): a dictionary of buffer to layout
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_layout({A_shared: layout})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i, bx * block_N + j]
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
return main
"""
# layout_map is a dictionary of buffer to layout
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer.
A safe value of a buffer is the value that will be used when the
buffer is accessed out of bounds.
Args:
safe_value_map (dict): a dictionary of buffer to safe value
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_safe_value({A: safe_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
return main
"""
# safe_value_map is a dictionary of buffer to safe value
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
"""Annotate the L2 hit ratio of the buffer, detailed explanation please refer to:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses
Args:
l2_hit_ratio_map (dict): a dictionary of buffer to L2 hit ratio value
Example:
# 0.5 is the hit ratio
T.annotate_l2_hit_ratio({A: 0.5})
"""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio)
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
from .symbolics import dynamic, symbolic # noqa: F401
from .annotations import ( # noqa: F401
use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
)
def import_source(source: Optional[str] = None):
......
"""Annotation helpers exposed on the TileLang language surface."""
from typing import Callable, Dict
from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr
__all__ = [
"use_swizzle",
"annotate_layout",
"annotate_safe_value",
"annotate_l2_hit_ratio",
]
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
"""Annotate a kernel to use a specific threadblock swizzle pattern."""
device_func = "rasterization2DRow" if order == "row" else "rasterization2DColumn"
if not enable:
return None
return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>")
def annotate_layout(layout_map: Dict):
"""Annotate the layout of the buffer."""
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer."""
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
"""Annotate the L2 hit ratio of the buffer."""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio)
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
"""Symbolic variable helpers exposed on the TileLang language surface."""
from tvm import tir
from tilelang.utils import deprecated
__all__ = ["dynamic", "symbolic"]
@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9")
def dynamic(name: str, dtype: str = "int32"):
"""
Create a TIR dynamic symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)
@deprecated("T.symbolic(...)", "T.dynamic(...)")
def symbolic(name: str, dtype: str = "int32"):
"""Deprecated alias for `T.dynamic`."""
return tir.Var(name, dtype)
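In short, both helpers construct a plain `tir.Var`; only the spelling changes. A minimal sketch of the rename, assuming `T` is `tilelang.language`:

```python
import tilelang.language as T

m_new = T.dynamic("m")   # the spelling this change recommends
m_old = T.symbolic("m")  # deprecated alias; still returns an equivalent tir.Var
# Both are int32 TIR variables suitable for use as symbolic shape dimensions.
assert m_new.dtype == m_old.dtype == "int32"
```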