"src/vscode:/vscode.git/clone" did not exist on "5b11099a0bbcd641d1df3b3d46010aebc791a149"
Unverified commit cb37bfef, authored by Lei Wang, committed by GitHub

[Refactor] Refactor barrier management (#744)

* Introduce Barrier

* Enhance CUDA kernel with new barrier management and post-processing support

- Added a new CUDA kernel implementation in `example_mla_decode.py` for improved performance with shared memory barriers.
- Refactored barrier handling in `codegen_cuda.cc` and `codegen_hip.cc` to use a more flexible mbarrier structure (see the sketch after this list).
- Updated intrinsic definitions from `ptx_stmatirx` to `ptx_stmatrix` across multiple files for consistency.
- Introduced additional print statements for debugging in the lowering phase of the TileLang engine.
- Enhanced the overall structure and readability of the codebase.

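For context, the shared-memory mbarrier pattern this refactor organizes pairs an arrive step with a parity wait. A minimal sketch of that pattern follows; `T.alloc_barrier`, `T.barrier_arrive`, and `T.barrier_wait` are hypothetical stand-ins for the reorganized intrinsics, not TileLang's confirmed API.

```python
import tilelang.language as T


@T.prim_func
def staged_copy(A: T.Tensor((128,), "float16")):
    with T.Kernel(1, threads=128) as bx:
        smem = T.alloc_shared((128,), "float16")
        # Hypothetical: one mbarrier in shared memory expecting 128 arrivals.
        bar = T.alloc_barrier(arrive_count=128)
        T.copy(A, smem)                # stage data into shared memory
        T.barrier_arrive(bar)          # hypothetical: each thread signals completion
        T.barrier_wait(bar, parity=0)  # hypothetical: block until all arrivals land
```
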
* Remove unused barrier-handling code from the CUDA and HIP code generators, reducing complexity in the barrier-management logic.

* Enhance barrier management in TileLang

- Introduced a new intrinsic `allocate_barrier` for dynamic barrier allocation in the TileLang framework (a speculative usage sketch follows this list).
- Updated CUDA code generation to support the new barrier structure, allowing for improved synchronization in shared memory.
- Refactored existing barrier handling logic to accommodate the new intrinsic and streamline code.
- Added print statements for debugging purposes in various examples and the lowering phase of the TileLang engine.
- Removed deprecated memory scope handling code to enhance clarity and maintainability.

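A speculative sketch of what kernel code using this intrinsic might have looked like; the name `allocate_barrier` comes from the commit, but its signature (`arrive_count`) is an assumption, and the intrinsic is removed again later in this PR.

```python
import tilelang.language as T


@T.prim_func
def kernel_with_dynamic_barrier(A: T.Tensor((128,), "float16")):
    with T.Kernel(1, threads=128) as bx:
        smem = T.alloc_shared((128,), "float16")
        # Hypothetical signature: dynamically allocate one barrier that
        # expects 128 thread arrivals, instead of declaring it statically.
        bar = T.allocate_barrier(arrive_count=128)
        T.copy(A, smem)
```
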
* lint fix

* lint fix

* Remove `allocate_barrier` intrinsic and related code from TileLang to streamline barrier management. This includes updates to CUDA code generation and the removal of associated Python wrappers, enhancing code clarity and maintainability.

* Refactor logging in JITKernel to improve kernel compilation tracking

- Removed unused import of `torch.backends` in the example file.
- Introduced logging for kernel compilation in `JITKernel`, replacing print statements with structured logging for better traceability (a logging-setup sketch follows this list).
- Added an assertion to ensure the presence of the `global_symbol` attribute in the kernel function.

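Because the compile notice now goes through the standard `logging` module rather than `print`, downstream users opt in explicitly. A minimal setup sketch, assuming the logger name follows the module path under the `tilelang` package:

```python
import logging

# Surface TileLang's kernel-compilation messages with timestamps.
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s: %(message)s")
logging.getLogger("tilelang").setLevel(logging.INFO)
```
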
* Refactor dequantization tests and update barrier function

- Removed the test for `example_dequant_gemm_bf16_fp4_hopper_serial` to streamline the testing suite.
- Updated the `mbarrier_cp_async_arrive` function to accept both pointer and non-pointer arguments, making barrier management more flexible (a dispatch sketch follows this list).

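The real overloads live in the generated device-side code; the sketch below only models the same pointer-or-index dispatch at a Python wrapper level. The dispatch rule and the `tl.mbarrier_cp_async_arrive` op name are assumptions for illustration.

```python
from tvm import tir


def mbarrier_cp_async_arrive(mbar):
    """Sketch: accept either a pointer expression or a plain barrier index."""
    if isinstance(mbar, tir.PrimExpr) and mbar.dtype == "handle":
        arg = mbar  # pointer form: pass the shared-memory address through
    else:
        arg = tir.IntImm("int32", int(mbar))  # non-pointer form: barrier id
    return tir.call_intrin("handle", "tl.mbarrier_cp_async_arrive", arg)
```
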
* Update CI configuration to increase pytest parallelism from 4 to 8 threads for improved test execution speed.

* Fix typos in rasterization parameters and update import path for cached module

- Corrected the spelling of `enable_rasteration` to `enable_rasterization` in the matmul function and its usage.
- Updated the import statement for the `cached` module to reflect its new path in the cache submodule (see the example after this list).
- Added `StridedTensor` import in the language module for enhanced tensor functionality.

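The relocated import as it would appear in user code, assuming the new home is the `tilelang.cache` submodule (the exact path is not visible in this excerpt):

```python
# before: from tilelang import cached
from tilelang.cache import cached  # assumed new location
```
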
* Update ci.yml
parent eccdfe17
@@ -117,7 +117,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
     if allow_fence_proxy(target=target):
         # on Hopper devices, wgmma is an async proxy,
         # so we need to inject a fence proxy before it
@@ -129,7 +128,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # as it will flatten index computing
     mod = tilelang.transform.ConfigIndexBitwidth()(mod)
     mod = tir.transform.Simplify()(mod)
     mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
     mod = tilelang.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
@@ -155,7 +153,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)
     # Global Barrier Synchronization must be applied before
     # SplitHostDevice pass, as the global barrier
     if allow_global_thread_synchronization():
...
@@ -11,6 +11,9 @@
 from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
                                   NVRTCKernelAdapter, TorchDLPackKernelAdapter)
 from tilelang.profiler import Profiler, TensorSupplyType
 from tilelang.utils.target import AVALIABLE_TARGETS, determine_target
+import logging
+
+logger = logging.getLogger(__name__)
@@ -115,7 +118,10 @@ class JITKernel(object):
         # NOTE(Chenggang): printing makes it easier for the training/inference framework
         # to tell whether a communication timeout comes from compilation
         if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
-            print(f"TileLang begins to compile kernel `{func.__name__}` with `{out_idx=}`")
+            # the kernel function must carry a "global_symbol" attribute
+            func_name = func.attrs.get("global_symbol")
+            assert func_name is not None, "func must have global_symbol"
+            logger.info(f"TileLang begins to compile kernel `{func_name}` with `{out_idx=}`")

         # Compile the TileLang function and create a kernel adapter for execution.
         adapter = self._compile_and_create_adapter(func, out_idx)
...
@@ -17,6 +17,7 @@ from .proxy import (
     make_tensor,  # noqa: F401
     Buffer,  # noqa: F401
     Tensor,  # noqa: F401
+    StridedTensor,  # noqa: F401
     FragmentBuffer,  # noqa: F401
     SharedBuffer,  # noqa: F401
     LocalBuffer,  # noqa: F401
@@ -67,7 +68,6 @@ from .customize import (
 from .logical import any_of, all_of  # noqa: F401
 from .builtin import *  # noqa: F401
 from .memscope import *  # noqa: F401
 from .utils import index_to_coordinates  # noqa: F401
...
from tvm.ffi.registry import register_func
from tvm.ir import make_node


@register_func("tvm.info.mem.local.var")
def mem_info_local_var():
    """Get memory information for the `local.var` memory scope.

    Returns:
        A `target.MemoryInfo` node describing the scope's size limits.
    """
    return make_node(
        "target.MemoryInfo",
        unit_bits=8,
        max_num_bits=64,
        max_simd_bits=128,
        head_address=None,
    )
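
A function registered this way is retrieved through the same FFI registry, keyed by the scope name. A quick check, assuming the matching `get_global_func` lookup lives in the same registry module:

```python
from tvm.ffi.registry import get_global_func

# Fetch the memory info registered above and inspect a reported limit.
mem_info = get_global_func("tvm.info.mem.local.var")()
print(mem_info.max_num_bits)  # 64, per the registration above
```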