Commit dca2fb48 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce flag to visualize shared memory merge plan (#496)

* Remove a debug print statement from block_sparse_attn_triton.py and implement a timeout handler for function execution in the autotuner. This improves the robustness of the autotuner by allowing it to handle timeouts gracefully.

* Enhance the autotuner module by adding a timeout handler for function execution, improving robustness in handling long-running tasks. This change includes the introduction of a custom TimeoutException and updates to the run_with_timeout function for better signal management.

* Add merge shared memory allocations pass and related configurations

- Introduced a new pass for merging shared memory allocations in GPU kernels, allowing for more efficient memory usage.
- Registered configuration options for debugging and controlling the merging behavior.
- Updated relevant files to integrate the new pass into the TileLang engine and transform modules.
- Adjusted import paths and added documentation for the new functionality.

* Reduce num_stages parameter in GEMM functions from 3 to 1 for improved performance in test_tilelang_kernel_gemm.py
parent cde1886f
@@ -60,8 +60,6 @@ def _fwd_kernel_inner(
     mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
     # print
-    if k_block_col_idx == 3:
-        print("mask_val", mask_val)
     if mask_val == True:
         start_n = k_block_col_idx * BLOCK_N
         # -- compute qk ----
...
@@ -16,6 +16,7 @@
 namespace tvm {
 namespace tl {
+TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
...
@@ -12,7 +12,8 @@
 namespace tvm {
 namespace tl {
+static constexpr const char *kDebugMergeSharedMemoryAllocations =
+    "tl.debug_merge_shared_memory_allocations";
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
 static constexpr const char *kDisableSafeMemoryLegalize =
     "tl.disable_safe_memory_legalize";
...
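For reference, a minimal sketch of how the flag registered above can be switched on from Python. tvm.transform.PassContext is TVM's standard mechanism for pass configs; the tilelang.lower entry point and the program variable are assumptions made purely for illustration.

import tvm
import tilelang

# "tl.debug_merge_shared_memory_allocations" is the key registered above.
# It defaults to off, so the merge plan is only printed when explicitly enabled.
with tvm.transform.PassContext(config={"tl.debug_merge_shared_memory_allocations": True}):
    kernel = tilelang.lower(program)  # entry point and `program` are assumed for this sketch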
This diff is collapsed.
@@ -354,7 +354,7 @@ def run_gemm_sr(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=1,
     num_threads=128,
 ):
     program = matmul_sr(
...
@@ -470,7 +470,7 @@ def run_gemm_rs(
     block_M,
     block_N,
     block_K,
-    num_stages=3,
+    num_stages=1,
     num_threads=128,
 ):
     program = matmul_rs(
...
@@ -17,6 +17,26 @@ import concurrent.futures
 import torch
 import os
 import sys
+import signal
+
+
+class TimeoutException(Exception):
+    pass
+
+
+def timeout_handler(signum, frame):
+    raise TimeoutException()
+
+
+def run_with_timeout(func, timeout, *args, **kwargs):
+    signal.signal(signal.SIGALRM, timeout_handler)
+    signal.alarm(timeout)
+    try:
+        result = func(*args, **kwargs)
+    finally:
+        signal.alarm(0)
+    return result
+
 # Configure logging for the autotuner module
 # TODO: Consider creating a common logger in utils
...
@@ -374,12 +394,8 @@ class AutoTuner:
                 # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
                 # Because tma init may behave strangely with one thread
                 # latency, ref_latency = target_fn(jit_context)
-                benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-                future = benchmark_executor.submit(
-                    functools.partial(device_wrapper, target_fn, torch.cuda.current_device()),
-                    jit_context)
-                latency, ref_latency = future.result(timeout=timeout)
-            except concurrent.futures.TimeoutError:
+                latency, ref_latency = run_with_timeout(target_fn, timeout, jit_context)
+            except TimeoutException:
                 logger.info(
                     f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
                 )
...
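For context, a hedged usage sketch of the new helper (not part of the diff). Because run_with_timeout relies on SIGALRM, it only works on Unix and only when called from the main thread, and signal.alarm() takes whole seconds; slow_benchmark below is a made-up stand-in for target_fn.

import time

def slow_benchmark(x):
    time.sleep(5)  # stands in for a long-running target_fn
    return x * 2

try:
    result = run_with_timeout(slow_benchmark, 2, 21)  # SIGALRM fires after ~2 s
except TimeoutException:
    result = None  # mirror the autotuner: log the timeout and skip this config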
@@ -128,7 +128,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
-    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
     mod = tilelang.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)
...
@@ -331,3 +331,14 @@ def EliminateStorageSyncForMBarrier():
     """EliminateStorageSyncForMBarrier
     """
     return _ffi_api.EliminateStorageSyncForMBarrier()  # type: ignore
+
+
+def MergeSharedMemoryAllocations():
+    """MergeSharedMemoryAllocations
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
...
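As with the other factories in this module, the wrapper returns a tvm.transform.Pass that is applied to an IRModule, which is exactly how OptimizeForTarget uses it above. A minimal sketch, assuming mod is an IRModule that has already been through SplitHostDevice:

import tilelang

# The returned pass rewrites device functions so that their individual shared
# memory allocations are merged into one reusable buffer.
merge_pass = tilelang.transform.MergeSharedMemoryAllocations()
mod = merge_pass(mod)  # `mod` is an assumed, already host/device-split IRModule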
@@ -27,6 +27,9 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
     """Disable safe memory access optimization. Default: False"""
+    TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
+    """Enable debug information for merge shared memory allocations. Default: False"""
+
     # TIR related configs
     TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
     """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
...
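Since PassConfigKey mixes in str, each member's .value is the raw key string registered on the C++ side, so configs can be built without repeating the literal. A small sketch; the import path is assumed for illustration only.

from tilelang.transform import PassConfigKey  # exact import path assumed

config = {PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS.value: True}
assert config["tl.debug_merge_shared_memory_allocations"] is True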