Unverified Commit cdc67fc4 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[PassConfig] Introduce PassConfig `TL_STORAGE_REWRITE_DETECT_INPLACE` (#1089)

* Enable configurable StorageRewrite inplace detection

  - Add kStorageRewriteDetectInplace constant and register the flag with PassContext so C++ code no longer hard-codes the key.
  - Wire StorageRewrite to include TileLang builtin constants and honor the new config toggle when deciding inplace reuse.
  - Document the flag across Python surfaces (PassConfigKey, JIT/autotuner docs) with usage guidance and simplified IR examples.

* lint fix

* add test

* lint fix
parent 0c7e7419
......@@ -33,6 +33,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
// Boolean option read by StorageRewrite through PassContext::GetConfig to
// decide whether inplace detection is enabled.
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
// Returns DataType::UInt(8, 128).
// NOTE(review): presumably the 128-byte representation of a CUDA tensor map
// descriptor — confirm against callers.
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
......
......@@ -48,6 +48,8 @@ static constexpr const char *kEnablePTXASVerboseOutput =
static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
 * \brief Whether StorageRewrite enables inplace detection when rewriting
 *        storage. StorageRewrite falls back to false when the key is unset.
 *
 * kStorageRewriteDetectInplace: bool
 */
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
/*!
* \brief Whether to disable dynamic tail split
*
......
......@@ -38,6 +38,7 @@
#include <unordered_set>
#include <utility>
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "runtime/thread_storage_scope.h"
#include "tir/ir/buffer_common.h"
......@@ -1914,6 +1915,8 @@ using namespace tir::transform;
namespace transform {
Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, const IRModule &m, PassContext ctx) {
bool detect_inplace =
ctx->GetConfig<Bool>(kStorageRewriteDetectInplace, Bool(false)).value();
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem =
......@@ -1939,9 +1942,9 @@ Pass StorageRewrite() {
reuse_require_exact_matched_dtype = true;
}
auto *n = f.CopyOnWrite();
n->body =
StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
reuse_require_exact_matched_dtype);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), detect_inplace,
enable_reuse,
reuse_require_exact_matched_dtype);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
......
import tilelang
import tilelang.testing
from tilelang import language as T
# Builds the test kernel with the default JIT configuration (no pass_configs
# override), i.e. without opting in to StorageRewrite inplace detection.
@tilelang.jit
def _compile_kernel_without_inplace():
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
        with T.Kernel(num_tokens, threads=32) as pid:
            # Two separate scalar temporaries: `read` holds the loaded element,
            # `write` holds the doubled value. Keep these names as-is; the test
            # searches the generated kernel source for "read = (read * 2);".
            read = T.alloc_var("int")
            read = x[pid]
            write = T.alloc_var("int")
            write = read * 2
            x[pid] = write

    return buggy_kernel
# Builds the same test kernel as the default variant, but compiled with
# TL_STORAGE_REWRITE_DETECT_INPLACE set to True in the JIT pass configs.
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
    },)
def _compile_kernel_with_inplace():
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
        with T.Kernel(num_tokens, threads=32) as pid:
            # Two separate scalar temporaries: `read` holds the loaded element,
            # `write` holds the doubled value. Keep these names as-is; the test
            # expects the generated source to contain "read = (read * 2);"
            # when inplace detection is enabled.
            read = T.alloc_var("int")
            read = x[pid]
            write = T.alloc_var("int")
            write = read * 2
            x[pid] = write

    return buggy_kernel
def _get_device_kernel_script(detect_inplace: bool) -> str:
    """Compile the test kernel and return its generated kernel source.

    Args:
        detect_inplace: When True, use the variant compiled with
            TL_STORAGE_REWRITE_DETECT_INPLACE enabled; otherwise use the
            variant compiled with the default pass configs.

    Returns:
        The string produced by the compiled kernel's ``get_kernel_source()``.
    """
    build_kernel = (
        _compile_kernel_with_inplace if detect_inplace else _compile_kernel_without_inplace)
    return build_kernel().get_kernel_source()
def test_storage_rewrite_detect_inplace_toggle():
    """Check that the inplace-detection flag changes the generated kernel.

    The statement that reuses `read` as the destination must be absent from
    the source compiled with default configs and present in the source
    compiled with the flag enabled.
    """
    inplace_stmt = "read = (read * 2);"
    source_default = _get_device_kernel_script(detect_inplace=False)
    source_inplace = _get_device_kernel_script(detect_inplace=True)
    assert inplace_stmt not in source_default
    assert inplace_stmt in source_inplace
# When executed directly, run the tests via tilelang's testing entry point.
if __name__ == "__main__":
    tilelang.testing.main()
......@@ -37,14 +37,7 @@ class CompileArgs:
target_host: Target host for cross-compilation (default: None).
verbose: Whether to enable verbose output (default: False).
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
Refer to `tilelang.PassConfigKey` for supported options.
"""
out_idx: Optional[Union[List[int], int]] = None
......
......@@ -59,14 +59,7 @@ def compile(
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_warp_specialized": bool, default: False
"tl.config_index_bitwidth": int, default: None
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
"tl.disable_safe_memory_legalize": bool, default: False
Refer to `tilelang.transform.PassConfigKey` for supported options.
"""
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
......
......@@ -71,11 +71,7 @@ class JITKernel(object):
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"tl.disable_dynamic_tail_split": bool, default: False
"tl.dynamic_vectorize_size_bits": int, default: 128
Refer to `tilelang.PassConfigKey` for supported options.
from_database : bool, optional
Whether to create a TorchFunction from a database.
"""
......
......@@ -69,6 +69,46 @@ class PassConfigKey(str, Enum):
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""
TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace"
"""Control StorageRewrite inplace detection.
When False (default) StorageRewrite keeps distinct temporaries for patterns
such as `dst[i] = f(src[i])`, avoiding implicit aliasing:
```
read = T.allocate([1], "int32", "local.var")
write = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
write_buf = T.Buffer((1,), "int32", data=write, scope="local.var")
write_buf[0] = read_buf[0] * 2
f(write_buf[0])
```
Setting the flag to True allows StorageRewrite to reuse the `read` buffer
for the write when it can prove the update is safely inplace, producing IR
like:
```
read = T.allocate([1], "int32", "local.var")
read_buf = T.Buffer((1,), "int32", data=read, scope="local.var")
read_buf[0] = read_buf[0] * 2
f(read_buf[0])
```
This reduces local memory usage but introduces aliasing between the buffers.
Usage:
```python
import tilelang
from tilelang.transform import PassContext, PassConfigKey
with PassContext(
config={PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE.value: True}
):
mod = tilelang.transform.StorageRewrite()(mod)
```
"""
# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment