"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "8f5c671e4c4f3874d6ee4d73155134b62374b9e2"
Commit dca2fb48 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce flag to visualize shared memory merge plan (#496)

* Remove debug print statement from block_sparse_attn_triton.py and implement a timeout handler in autotuner for function execution. This enhances the robustness of the autotuner by allowing it to handle timeouts gracefully.

* Enhance the autotuner module by adding a timeout handler for function execution, improving robustness in handling long-running tasks. This change includes the introduction of a custom TimeoutException and updates to the run_with_timeout function for better signal management.

* Add merge shared memory allocations pass and related configurations

- Introduced a new pass for merging shared memory allocations in GPU kernels, allowing for more efficient memory usage.
- Registered configuration options for debugging and controlling the merging behavior.
- Updated relevant files to integrate the new pass into the TileLang engine and transform modules.
- Adjusted import paths and added documentation for the new functionality (a usage sketch for the new flag follows the commit message below).

* Reduce the num_stages parameter in the GEMM functions of test_tilelang_kernel_gemm.py from 3 to 1 for improved performance
parent cde1886f
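For context, the new flag plugs into TVM's standard pass-config mechanism. Below is a minimal usage sketch, not part of this commit's diff: the config key string is taken from the diff, but `tilelang.lower` as the entry point and the `program` variable are assumptions that may differ by tilelang version.

# Hedged usage sketch (not part of this commit): enable the merge-plan
# debug output through TVM's pass-config mechanism. The key string comes
# from this diff; `tilelang.lower` and `program` are assumptions.
import tvm
import tilelang

with tvm.transform.PassContext(
        config={"tl.debug_merge_shared_memory_allocations": True}):
    mod = tilelang.lower(program)  # `program`: any tilelang kernel (assumed)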
@@ -60,8 +60,6 @@ def _fwd_kernel_inner(
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
-    # print
-    if k_block_col_idx == 3:
-        print("mask_val", mask_val)
     if mask_val == True:
         start_n = k_block_col_idx * BLOCK_N
         # -- compute qk ----
......
@@ -16,6 +16,7 @@
 namespace tvm {
 namespace tl {
+TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
......
@@ -12,7 +12,8 @@
 namespace tvm {
 namespace tl {
+static constexpr const char *kDebugMergeSharedMemoryAllocations =
+    "tl.debug_merge_shared_memory_allocations";
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
 static constexpr const char *kDisableSafeMemoryLegalize =
     "tl.disable_safe_memory_legalize";
......
This diff is collapsed.
@@ -354,7 +354,7 @@ def run_gemm_sr(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=1,
     num_threads=128,
 ):
     program = matmul_sr(
@@ -470,7 +470,7 @@ def run_gemm_rs(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=1,
     num_threads=128,
 ):
     program = matmul_rs(
......
@@ -17,6 +17,26 @@ import concurrent.futures
 import torch
 import os
 import sys
+import signal
+
+
+class TimeoutException(Exception):
+    pass
+
+
+def timeout_handler(signum, frame):
+    raise TimeoutException()
+
+
+def run_with_timeout(func, timeout, *args, **kwargs):
+    signal.signal(signal.SIGALRM, timeout_handler)
+    signal.alarm(timeout)
+    try:
+        result = func(*args, **kwargs)
+    finally:
+        signal.alarm(0)
+    return result
+
 # Configure logging for the autotuner module
 # TODO: Consider creating a common logger in utils
@@ -374,12 +394,8 @@ class AutoTuner:
             # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
             # Because tma init may behave strangely with one thread
             # latency, ref_latency = target_fn(jit_context)
-            benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-            future = benchmark_executor.submit(
-                functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
-                jit_context)
-            latency, ref_latency = future.result(timeout=timeout)
-        except concurrent.futures.TimeoutError:
+            latency, ref_latency = run_with_timeout(target_fn, timeout, jit_context)
+        except TimeoutException:
             logger.info(
                 f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
             )
......
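The two autotuner hunks above replace the thread-pool timeout with a SIGALRM-based one. Worth noting: signal.alarm is Unix-only, a handler can only be installed from the main thread, and the alarm has whole-second granularity. Here is a self-contained demo of the same pattern; slow() is a made-up stand-in for a long-running benchmark.

# Standalone demo of the SIGALRM timeout pattern used above. Unix-only;
# the handler must be installed from the main thread, and signal.alarm()
# has whole-second resolution.
import signal
import time


class TimeoutException(Exception):
    pass


def timeout_handler(signum, frame):
    raise TimeoutException()


def run_with_timeout(func, timeout, *args, **kwargs):
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)  # arm the alarm
    try:
        result = func(*args, **kwargs)
    finally:
        signal.alarm(0)  # always disarm, even if func raised
    return result


def slow():  # hypothetical long-running task
    time.sleep(5)
    return "done"


try:
    print(run_with_timeout(slow, 2))
except TimeoutException:
    print("timed out")  # expected: slow() exceeds the 2-second budget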
@@ -128,7 +128,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
-    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
     mod = tilelang.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
@@ -331,3 +331,14 @@ def EliminateStorageSyncForMBarrier():
     """EliminateStorageSyncForMBarrier
     """
     return _ffi_api.EliminateStorageSyncForMBarrier()  # type: ignore
+
+
+def MergeSharedMemoryAllocations():
+    """MergeSharedMemoryAllocations
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
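To make the wiring concrete: the OptimizeForTarget hunk above swaps the stock TIR pass for this tilelang-owned wrapper at the same point in the lowering pipeline. A minimal sketch of that placement follows; `lower_device_part` is a hypothetical helper name, and `mod` stands for any IRModule reaching this stage.

# Sketch only: where the new pass sits relative to its neighbors, mirroring
# the OptimizeForTarget change above. `lower_device_part` is a hypothetical
# name; the pass sequence itself is taken from the diff.
from tvm import tir
import tilelang.transform


def lower_device_part(mod):
    mod = tir.transform.SplitHostDevice()(mod)
    # Previously: tir.transform.MergeSharedMemoryAllocations()
    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
    mod = tilelang.transform.MakePackedAPI()(mod)
    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
    return mod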
@@ -27,6 +27,9 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
     """Disable safe memory access optimization. Default: False"""
+    TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
+    """Enable debug information for merge shared memory allocations. Default: False"""
+
     # TIR related configs
     TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
     """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......
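Because PassConfigKey subclasses str, a member's .value is exactly the registered config string, so configs can be built from the enum instead of raw string literals. A hedged sketch; the import path for PassConfigKey is an assumption, only the key string itself is confirmed by this diff.

# Hedged sketch: build a pass config from the enum rather than raw strings.
# The import path below is assumed, not confirmed by this commit.
import tvm
from tilelang.transform import PassConfigKey  # assumed location

config = {PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS.value: True}
with tvm.transform.PassContext(config=config):
    ...  # any lowering run here prints the shared-memory merge plan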