Unverified Commit a7730272 authored by Lei Wang, committed by GitHub

[Language] Recommend using `T.dynamic` instead of `T.symbolic` (#1076)

* recommend using T.dynamic instead of T.symbolic

* lint fix

* lint fix
parent fd6cec58
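In short: `T.dynamic` is now the preferred way to declare a dynamic shape dimension, and `T.symbolic` stays available as a deprecated alias, so existing code keeps working. A minimal sketch of the new spelling (the `heads` factor mentioned in the comment is illustrative only and not defined here):

```python
import tilelang.language as T

# Before this change (still works, now deprecated):
#   seq_len = T.symbolic("seq_len")
# After this change (recommended):
seq_len = T.dynamic("seq_len")         # int32 by default
batch = T.dynamic("batch", "int64")    # the dtype can be overridden

# The returned object is an ordinary TVM tir.Var, so it composes like a static
# dimension would, e.g. [seq_len * heads, index_dim] in a shape expression.
q_shape = [batch, seq_len]
```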
@@ -178,7 +178,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 return matmul_relu_kernel
-M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
+M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
 N = 1024
 K = 1024
 block_M = 128
...
@@ -89,7 +89,7 @@ def elementwise_add(
 In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
 ```python
-program = elementwise_add(T.symbolic("N"), threads=256, dtype="bfloat16")
+program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
 kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
 ```
...
@@ -223,12 +223,12 @@ class SparseFlashAttn(torch.nn.Module):
 block_N=block_N,
 block_H=self.block_H,
 page_block_size=page_block_size,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
 num_pages=num_pages,
-max_num_blocks_per_seq=T.symbolic("max_num_blocks_per_seq"),
-max_selected_blocks=T.symbolic("max_selected_blocks"),
+max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
+max_selected_blocks=T.dynamic("max_selected_blocks"),
 )
 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
...
@@ -206,11 +206,11 @@ class SparseFlashAttn(torch.nn.Module):
 self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-max_selected_blocks=T.symbolic("max_selected_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+max_selected_blocks=T.dynamic("max_selected_blocks"))
 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -301,11 +301,11 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql
 kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-max_selected_blocks=T.symbolic("max_selected_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+max_selected_blocks=T.dynamic("max_selected_blocks"))
 output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial)
 return output
...
@@ -193,11 +193,11 @@ class SparseFlashAttn(torch.nn.Module):
 self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=self.block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-num_blocks=T.symbolic("num_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+num_blocks=T.dynamic("num_blocks"))
 props = torch.cuda.get_device_properties(torch.device("cuda:0"))
 self.num_sm = props.multi_processor_count
@@ -282,11 +282,11 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens,
 kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
 block_N=block_size,
 block_H=block_H,
-num_split=T.symbolic("num_split"),
+num_split=T.dynamic("num_split"),
 num_stages=2,
 threads=128,
-max_cache_seqlen=T.symbolic("max_cache_seqlen"),
-num_blocks=T.symbolic("num_blocks"))
+max_cache_seqlen=T.dynamic("max_cache_seqlen"),
+num_blocks=T.dynamic("num_blocks"))
 glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda')
 Output_partial = torch.empty((batch, heads, num_split, dim_v),
 dtype=torch.float32,
...
@@ -103,8 +103,8 @@ def mqa_attn_return_logits(
 accum_dtype = "float"
 index_dtype = "int32"
-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")
 index_q_shape = [seq_len * heads, index_dim]
 index_k_shape = [seq_len_kv, index_dim]
@@ -182,8 +182,8 @@ def clean_logits_(
 threads: int = 512,
 block_K: int = 4096,
 ):
-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")
 dtype = "float"
 indices_dtype = "int32"
...
@@ -34,7 +34,7 @@ def fast_round_scale(amax, fp8_max_inv):
 @tilelang.jit(pass_configs=pass_configs)
 def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
-M = T.symbolic("M")
+M = T.dynamic("M")
 fp8_min = -448.0
 fp8_max = 448.0
 fp8_max_inv = 1 / fp8_max
@@ -110,7 +110,7 @@ def act_quant(x: torch.Tensor,
 def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
 assert out_dtype in [BF16, "float32"]
-M = T.symbolic("M")
+M = T.dynamic("M")
 group_size = 128
 block_M = 32
 block_N = 128
@@ -192,9 +192,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
 @tilelang.jit(out_idx=[4], pass_configs=pass_configs)
 def fp8_index_kernel(h: int, d: int):
-b = T.symbolic("b")
-m = T.symbolic("m")
-n = T.symbolic("n")
+b = T.dynamic("b")
+m = T.dynamic("m")
+n = T.dynamic("n")
 blk_n1 = 512
 blk_n2 = 128
...
@@ -37,9 +37,9 @@ def sparse_mla_fwd(
 else:
 sm_scale = sm_scale * 1.44269504 # log2(e)
-batch = T.symbolic("batch")
-seq_len = T.symbolic("seq_len")
-seq_len_kv = T.symbolic("seq_len_kv")
+batch = T.dynamic("batch")
+seq_len = T.dynamic("seq_len")
+seq_len_kv = T.dynamic("seq_len_kv")
 head_kv = heads // kv_group
 q_shape = [batch, seq_len, heads, dim + tail_dim]
...
@@ -26,8 +26,8 @@ def convert_to_uint32(x):
 @tilelang.jit(pass_configs=pass_configs)
 def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
-batch = T.symbolic("batch")
-seq_len = T.symbolic("seq_len")
+batch = T.dynamic("batch")
+seq_len = T.dynamic("seq_len")
 RADIX = 1 << 8
 BLOCK_SIZE = 1024
 SMEM_INPUT_SIZE = 4096 # assume the threshold bucket size after first pass is less than 4K
...
@@ -41,7 +41,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 return main
-M = 128 # M = T.symbolic("m") if you want to use dynamic shape
+M = 128 # M = T.dynamic("m") if you want to use dynamic shape
 N = 128
 K = 32
 block_M = 128
...
@@ -48,7 +48,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 return matmul_relu_kernel
-M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
+M = 1024 # M = T.dynamic("m") if you want to use dynamic shape
 N = 1024
 K = 1024
 block_M = 128
...
@@ -24,7 +24,7 @@ def test_empty_kernel_lowering():
 @tilelang.jit
 def _empty_with_dead_code_kernel():
-num_tokens = T.symbolic("num_tokens")
+num_tokens = T.dynamic("num_tokens")
 @T.prim_func
 def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
...
@@ -395,14 +395,14 @@ def run_ctypes_dynamic_shape(M,
 def test_ctypes_dynamic_shape():
 run_ctypes_dynamic_shape(
-T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
 run_ctypes_dynamic_shape(
-T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
+T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
 256, 32, 2)
 run_ctypes_dynamic_shape(
-T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
+T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
 "float16", 128, 256, 32, 2)
...
@@ -404,14 +404,14 @@ def run_cython_dynamic_shape(M,
 def test_cython_dynamic_shape():
 run_cython_dynamic_shape(
-T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
 run_cython_dynamic_shape(
-T.symbolic("m"), T.symbolic("n"), 768, False, False, "float16", "float16", "float16", 128,
+T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
 256, 32, 2)
 run_cython_dynamic_shape(
-T.symbolic("m"), T.symbolic("n"), T.symbolic("k"), False, False, "float16", "float16",
+T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
 "float16", 128, 256, 32, 2)
@@ -473,7 +473,7 @@ def run_cython_dynamic_shape_with_out_idx(M,
 def test_cython_dynamic_shape_with_out_idx():
 run_cython_dynamic_shape_with_out_idx(
-T.symbolic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
 def matmul_int_variable(
...
@@ -83,7 +83,7 @@ def run_tilelang_copy_with_stride(M=1024,
 def test_tilelang_copy_with_stride():
 run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
-run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)
+run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128)
 def tilelang_copy_bufferload(num_tokens, dtype="float16"):
...
@@ -41,7 +41,7 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
 def issue_1013_buggy_kernel():
 # NOTE: This kernel is mainly to test some corner cases in boundary check
-num_tokens = T.symbolic('num_tokens')
+num_tokens = T.dynamic('num_tokens')
 num_threads = 128
 @T.prim_func
...
"""The language interface for tl programs.""" """The language interface for tl programs."""
from typing import Optional, Callable, Dict from typing import Optional
# from .parser import * # from .parser import *
# now is fully compatible with the upstream # now is fully compatible with the upstream
# tir script # tir script
...@@ -84,124 +84,10 @@ from .builtin import * # noqa: F401 ...@@ -84,124 +84,10 @@ from .builtin import * # noqa: F401
from .utils import index_to_coordinates # noqa: F401 from .utils import index_to_coordinates # noqa: F401
from .symbolics import dynamic, symbolic # noqa: F401
def symbolic(name: str, dtype: str = "int32"): from .annotations import ( # noqa: F401
""" use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
Create a TIR symbolic variable. )
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
# If order is row, use rasterization2DRow, otherwise use rasterization2DColumn
# The panel size is the number of threads in a warp
# Use to improve the L2 Cache Locality
device_func = ("rasterization2DRow" if order == "row" else "rasterization2DColumn")
return attr(None, "threadblock_swizzle_pattern",
f"tl::{device_func}<{panel_size}>") if enable else None
def annotate_layout(layout_map: Dict):
"""Annotate the layout of the buffer
Args:
layout_map (Dict): a dictionary of buffer to layout
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_layout({A_shared: layout})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i, bx * block_N + j]
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
return main
"""
# layout_map is a dictionary of buffer to layout
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer.
A safe value of a buffer is the value that will be used when the
buffer is accessed out of bounds.
Args:
safe_value_map (dict): a dictionary of buffer to safe value
Returns:
block_attr: a block attribute
Example:
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_safe_value({A: safe_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
return main
"""
# safe_value_map is a dictionary of buffer to safe value
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
"""Annotate the L2 hit ratio of the buffer, detailed explanation please refer to:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses
Args:
l2_hit_ratio_map (dict): a dictionary of buffer to L2 hit ratio value
Example:
# 0.5 is the hit ratio
T.annotate_l2_hit_ratio({A: 0.5})
"""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio)
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
def import_source(source: Optional[str] = None): def import_source(source: Optional[str] = None):
......
"""Annotation helpers exposed on the TileLang language surface."""
from typing import Callable, Dict
from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr
__all__ = [
"use_swizzle",
"annotate_layout",
"annotate_safe_value",
"annotate_l2_hit_ratio",
]
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
"""Annotate a kernel to use a specific threadblock swizzle pattern."""
device_func = "rasterization2DRow" if order == "row" else "rasterization2DColumn"
if not enable:
return None
return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>")
def annotate_layout(layout_map: Dict):
"""Annotate the layout of the buffer."""
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer."""
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
"""Annotate the L2 hit ratio of the buffer."""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
_l2_hit_ratio_map[buffer.data] = float(hit_ratio)
return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
"""Symbolic variable helpers exposed on the TileLang language surface."""
from tvm import tir
from tilelang.utils import deprecated
__all__ = ["dynamic", "symbolic"]
@deprecated("T.dynamic(...)", "tir.Var(...)", "v0.1.9")
def dynamic(name: str, dtype: str = "int32"):
"""
Create a TIR dynamic symbolic variable.
Parameters:
name (str): Identifier for the variable in generated TIR.
dtype (str): Data type string for the variable (e.g., "int32"). Defaults to "int32".
Returns:
tir.Var: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
"""
return tir.Var(name, dtype)
@deprecated("T.symbolic(...)", "T.dynamic(...)")
def symbolic(name: str, dtype: str = "int32"):
"""Deprecated alias for `T.dynamic`."""
return tir.Var(name, dtype)
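Call sites that still use `T.symbolic` keep compiling; both helpers build the same `tir.Var`, and the alias only differs in the deprecation notice attached by the decorator. A short sketch, assuming `tilelang.utils.deprecated` warns at call time (the usual pattern for such decorators; the exact warning text is not pinned down here):

```python
import warnings

from tvm import tir
import tilelang.language as T

n_new = T.dynamic("n")           # recommended spelling
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    n_old = T.symbolic("n")      # deprecated alias that points users to T.dynamic

# Both helpers construct the same kind of variable with the same default dtype.
assert isinstance(n_new, tir.Var) and isinstance(n_old, tir.Var)
assert n_new.dtype == n_old.dtype == "int32"
```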