Commit b5faf25a authored by Lei Wang, committed by LeiWang1999

[Enhancement] Add new examples for warp specialization and TMA integration (#448)

* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic

* Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
* Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.

* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files

* Updated function names from CamelCase to snake_case for better consistency across the codebase.
* Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
* Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.

* [Refactor] Rename operations to snake_case for consistency

* Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
* Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
* Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.

* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier

* Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
* Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
* Introduced `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
* Enhanced the TileLang API with new methods for retrieving block and thread extents.
* Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
* Improved layout inference and kernel launch logic for better performance and clarity.
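The warp-specialization feature above splits a thread block into producer and consumer warp groups. As a conceptual illustration only (plain Python threads standing in for warp groups, a bounded queue standing in for double-buffered shared memory; all names here are hypothetical, not TileLang API), the role split can be sketched as:

```python
import threading
import queue

def warp_specialized_pipeline(num_tiles: int) -> list:
    """Toy model of the producer/consumer role split that warp
    specialization expresses: one group copies tiles while the other
    computes on them."""
    buf = queue.Queue(maxsize=2)  # plays the role of double-buffered shared memory
    results = []

    def producer():  # analogous to the copy (TMA) warp group
        for i in range(num_tiles):
            buf.put(i * 10)  # "load" a tile
        buf.put(None)  # signal completion

    def consumer():  # analogous to the compute (wgmma) warp group
        while True:
            tile = buf.get()
            if tile is None:
                break
            results.append(tile + 1)  # "compute" on the tile

    t0 = threading.Thread(target=producer)
    t1 = threading.Thread(target=consumer)
    t0.start(); t1.start()
    t0.join(); t1.join()
    return results
```

The bounded queue provides the backpressure that, on the GPU, mbarriers provide between the two warp groups.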

* [Refactor] Clean up code formatting and improve readability

* Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
* Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
* Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
* Ensured consistent spacing and formatting across multiple files to enhance overall code readability.

* lint fix

* [Refactor] Update mbarrier functions for improved clarity and consistency

* Refactored `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
* Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
* Adjusted `warpgroup.py` to remove unused thread extent variable, streamlining the code.
* Added detailed docstrings to clarify usage examples for memory barrier functions.

* Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` for improved code readability and separation of logical sections.
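The explicit-parameter `mbarrier_wait_parity` / `mbarrier_arrive` signatures follow the usual phase-parity protocol. The following is a toy single-threaded model of one plausible convention (not the real hardware mbarrier, and `ToyMBarrier` is a hypothetical name): arrivals flip the barrier's phase once the expected count is reached, and a parity wait succeeds once the phase has moved past the parity the waiter passes in.

```python
class ToyMBarrier:
    """Toy model of a phase-parity memory barrier."""

    def __init__(self, expected: int):
        self.expected = expected  # arrivals needed to complete one phase
        self.arrived = 0
        self.phase = 0  # flips 0 -> 1 -> 0 ... each time the barrier completes

    def arrive(self):
        """Model of mbarrier_arrive: count one arrival, completing the
        phase when all expected threads have arrived."""
        self.arrived += 1
        if self.arrived == self.expected:
            self.arrived = 0
            self.phase ^= 1  # complete the phase

    def try_wait_parity(self, parity: int) -> bool:
        """Model of mbarrier_wait_parity: the wait succeeds once the
        barrier's phase no longer equals the parity passed in."""
        return self.phase != parity
```

In kernels this is why waiters alternate the parity argument (e.g. `k % 2`) across pipeline iterations.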

* [Feature] Add examples for warp specialization and TMA barrier integration

* Introduced three new example scripts: `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py` demonstrating matrix multiplication with warp specialization and TMA barriers.
* Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
* Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
* Updated the `phase.py` to include TMA barrier injection in the optimization process.
* Improved documentation and comments for better clarity on usage and functionality.

* [Feature] Add example for warp specialization in GEMM with TMA barriers

* Introduced a new example script `example_warp_specialize_gemm_stage2.py` demonstrating matrix multiplication using warp specialization and TMA barriers.
* Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
* Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
* Enhanced documentation and comments for clarity on usage and functionality.
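The validation step these examples perform can be sketched as follows, with plain Python lists standing in for tensors (the real scripts use a compiled TileLang kernel and a PyTorch reference; `matmul_ref` and `check_kernel` are illustrative names, not part of the API):

```python
def matmul_ref(a, b):
    """Naive reference matmul over lists of lists."""
    K = len(b)
    N = len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(K)) for j in range(N)]
            for i in range(len(a))]

def check_kernel(kernel, a, b, tol=1e-6):
    """Compare a candidate kernel's output against the reference,
    elementwise within a tolerance (torch.testing.assert_close plays
    this role in the real scripts)."""
    out = kernel(a, b)
    ref = matmul_ref(a, b)
    return all(abs(o - r) <= tol
               for orow, rrow in zip(out, ref)
               for o, r in zip(orow, rrow))
```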

* lint fix

* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier Detection

* Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
* Enhanced the `WarpSpecialized` pass to utilize the detector, allowing for conditional substitution based on the detection results.
* Improved code organization by including necessary headers and utilizing the `IRVisitorWithAnalyzer` for analysis.
* This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, enhancing performance and correctness.
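The detector's core idea is a recursive walk over the statement tree looking for specific op names. A minimal sketch of that pattern in plain Python (statements modeled as nested tuples rather than real TIR nodes; `contains_ops` is a hypothetical name):

```python
def contains_ops(stmt, needles) -> bool:
    """Recursively walk a statement tree, modeled as nested tuples of
    ("op", name) leaves and ("seq", child, ...) nodes, and report
    whether any target op (e.g. a TMA copy or an mbarrier op) occurs
    anywhere inside."""
    kind = stmt[0]
    if kind == "op":
        return stmt[1] in needles
    if kind == "seq":
        return any(contains_ops(child, needles) for child in stmt[1:])
    return False
```

The pass would run such a check per function and substitute only the functions where TMA or mbarrier ops are found.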

* lint fix

* [Feature] Add new examples for warp specialization and TMA integration

* Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
* Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
* Added a test suite in `test_example_warp_specialize.py` to validate the functionality of the new examples.
* Updated the TileLang API to support these examples and improve kernel compilation and testing processes.
* Removed outdated example scripts to streamline the codebase and enhance clarity on available functionalities.

* lint fix

* Remove outdated example scripts for warp specialization and TMA integration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.
parent fce16b00
@@ -60,9 +60,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.IfStmtBinding()(mod)
     mod = tilelang.transform.MultiVersionBuffer()(mod)
     mod = tilelang.transform.WarpSpecialized()(mod)
+    mod = tilelang.transform.InjectTmaBarrier()(mod)
     # if tma is not enabled, we can also do pipeline planning
     # to get better performance with async copy
-    mod = tilelang.transform.InjectTmaBarrier()(mod)
     mod = tilelang.transform.PipelinePlanning()(mod)
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     # warp_specialized pass will pack the if stmt into the block
@@ -78,10 +78,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.InjectSoftwarePipeline()(mod)
     mod = tilelang.transform.MergeIfStmt()(mod)
     if allow_fence_proxy(target=target):
         # in hopper device, wgmma is an async proxy
         # so we need to inject a fence proxy before it
         mod = tilelang.transform.InjectFenceProxy()(mod)
     mod = tir.transform.LowerOpaqueBlock()(mod)
     mod = tilelang.transform.FlattenBuffer()(mod)
@@ -141,6 +141,7 @@ def compile(
             "tl.config_index_bitwidth": int, default: None
             "tl.disable_dynamic_tail_split": bool, default: False
             "tl.dynamic_vectorize_size_bits": int, default: 128
+            "tl.disable_safe_memory_legalize": bool, default: False
     """
     return cached(
         func=func,
@@ -111,28 +111,40 @@ def tma_store_wait(*args):
     return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_wait"), *args)


-def set_max_nreg(*args):
+def set_max_nreg(reg_count: int, is_inc: int):
     """Set the maximum number of registers to use.

+    Detailed Documentation:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
+
     Args:
-        *args: Variable arguments specifying register allocation limits
+        reg_count: int
+            The number of registers to allocate
+        is_inc: int
+            Whether to increment or decrement the register count
+            0 if decrement, 1 if increment

     Returns:
         tir.Call: A handle to the register setting operation
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), *args)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), reg_count, is_inc)


-def no_set_max_nreg(*args):
-    """Disable the maximum register limit setting.
+def inc_max_nreg(reg_count: int):
+    """Increment the maximum number of registers to use.
+    """
+    return set_max_nreg(reg_count, 1)

-    Args:
-        *args: Variable arguments for the operation

-    Returns:
-        tir.Call: A handle to the register limit disable operation
+def dec_max_nreg(reg_count: int):
+    """Decrement the maximum number of registers to use.
+    """
+    return set_max_nreg(reg_count, 0)
+
+
+def no_set_max_nreg():
+    """Disable the maximum register limit setting.
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))


 def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
@@ -4,6 +4,7 @@ from tvm.script.ir_builder.tir.frame import TIRFrame
 from tvm._ffi import register_object
 from tilelang import _ffi_api
 from .kernel import get_thread_bindings, get_thread_extents
+from typing import List


 @register_object("tl.WarpSpecializeFrame")
@@ -14,7 +15,7 @@ class WarpSpecializeFrame(TIRFrame):
     """


-def WarpSpecialize(warp_group_idx: int,):
+def WarpSpecialize(*warp_group_idx):
     """Tools to construct a warp group frame.

     Parameters
@@ -28,6 +29,10 @@ def WarpSpecialize(warp_group_idx: int,):
     -------
     res : Tuple[frame.LaunchThreadFrame]
         The result LaunchThreadFrame.
+
+    Examples:
+        >>> T.ws(0) -> if tx < 128
+        >>> T.ws(1) -> if tx >= 128 and tx < 256
+        >>> T.ws(0, 1) -> if tx < 128 or (tx >= 128 and tx < 256)
     """
     id_x, id_y, id_z = get_thread_bindings()
     ex_x, ex_y, _ = get_thread_extents()
@@ -35,7 +40,13 @@ def WarpSpecialize(warp_group_idx: int,):
     # only available for nvidia gpus.
     warp_group_size = 128

-    return _ffi_api.WarpSpecialize(warp_group_idx, tid, warp_group_size)
+    warp_group_ids: List[int] = []
+    for warp_group_id in warp_group_idx:
+        warp_group_ids.append(warp_group_id)
+    assert len(warp_group_ids) > 0, "warp_group_idx must be non-empty"
+    return _ffi_api.WarpSpecialize(warp_group_ids, tid, warp_group_size)


 # Alias for WarpSpecialize for more concise usage
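The thread predicate that the `T.ws(...)` examples in the docstring above describe can be modeled directly in plain Python (`warp_specialize_predicate` is an illustrative name, not part of the TileLang API):

```python
def warp_specialize_predicate(warp_group_ids, tx: int,
                              warp_group_size: int = 128) -> bool:
    """Model of the predicate WarpSpecialize generates: a thread with
    flattened index `tx` is active iff it falls inside one of the
    requested warp groups of `warp_group_size` threads each."""
    return any(
        gid * warp_group_size <= tx < (gid + 1) * warp_group_size
        for gid in warp_group_ids
    )
```

This reproduces the docstring's examples: group 0 covers `tx < 128`, group 1 covers `128 <= tx < 256`, and passing several ids unions the ranges.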
@@ -4,7 +4,9 @@ import pytest
 import random
 import torch
 import numpy as np
+from tilelang.contrib import nvcc
 from tvm.testing.utils import *
+from tvm.testing.utils import _compose
 from tilelang.utils.tensor import torch_assert_close as torch_assert_close
@@ -21,3 +23,82 @@ def set_random_seed(seed: int = 42) -> None:
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
+
+
+def requires_cuda_compute_version(major_version, minor_version=0, mode="ge"):
+    """Mark a test as requiring a minimum CUDA compute architecture.
+
+    Unit tests marked with this decorator will run only if the CUDA
+    compute architecture of the GPU satisfies the comparison against
+    `(major_version, minor_version)`.
+
+    This also marks the test as requiring CUDA support.
+
+    Parameters
+    ----------
+    major_version: int
+        The major version of the (major, minor) version tuple.
+    minor_version: int
+        The minor version of the (major, minor) version tuple.
+    mode: str
+        The mode of the comparison.
+        - "ge": greater than or equal to
+        - "gt": greater than
+        - "le": less than or equal to
+        - "lt": less than
+        - "eq": equal to
+    """
+    min_version = (major_version, minor_version)
+    try:
+        arch = nvcc.get_target_compute_version()
+        compute_version = nvcc.parse_compute_version(arch)
+    except ValueError:
+        # No GPU present. This test will be skipped from the
+        # requires_cuda() marks as well.
+        compute_version = (0, 0)
+
+    min_version_str = ".".join(str(v) for v in min_version)
+    compute_version_str = ".".join(str(v) for v in compute_version)
+
+    def compare(compute_version, min_version, mode) -> bool:
+        if mode == "ge":
+            return compute_version >= min_version
+        elif mode == "gt":
+            return compute_version > min_version
+        elif mode == "le":
+            return compute_version <= min_version
+        elif mode == "lt":
+            return compute_version < min_version
+        elif mode == "eq":
+            return compute_version == min_version
+        else:
+            raise ValueError(f"Invalid mode: {mode}")
+
+    requires = [
+        pytest.mark.skipif(
+            not compare(compute_version, min_version, mode),
+            reason=f"Requires CUDA compute {mode} {min_version_str}, but have {compute_version_str}",
+        ),
+        *requires_cuda.marks(),
+    ]
+
+    def inner(func):
+        return _compose([func], requires)
+
+    return inner
+
+
+def requires_cuda_compute_version_ge(major_version, minor_version=0):
+    return requires_cuda_compute_version(major_version, minor_version, mode="ge")
+
+
+def requires_cuda_compute_version_gt(major_version, minor_version=0):
+    return requires_cuda_compute_version(major_version, minor_version, mode="gt")
+
+
+def requires_cuda_compute_version_eq(major_version, minor_version=0):
+    return requires_cuda_compute_version(major_version, minor_version, mode="eq")
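The comparison at the heart of the decorator works because CUDA compute capabilities are `(major, minor)` tuples, so Python's lexicographic tuple ordering implements the check directly. A standalone sketch of just that logic (detached from pytest and nvcc; `meets_compute_version` is an illustrative name):

```python
def meets_compute_version(compute_version, min_version, mode="ge") -> bool:
    """Compare two (major, minor) compute-capability tuples under the
    given mode, mirroring the decorator's inner compare()."""
    ops = {
        "ge": compute_version >= min_version,
        "gt": compute_version > min_version,
        "le": compute_version <= min_version,
        "lt": compute_version < min_version,
        "eq": compute_version == min_version,
    }
    if mode not in ops:
        raise ValueError(f"Invalid mode: {mode}")
    return ops[mode]
```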
@@ -24,6 +24,9 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"
     """Disable TMA (Tensor Memory Access) lowering. Default: False"""

+    TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
+    """Disable safe memory access optimization. Default: False"""
+
     # TIR related configs
     TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
     """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""