Unverified commit b78d8404 authored by Lei Wang, committed by GitHub

[Language] Expose `T.get_warp_idx_sync` and `T.shuffle_elect` for efficient thread election (#989)



* Expose CUDA warp/lane intrinsics in TileLang frontend

* generalize warp indexing intrinsics and add coverage

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 32ddc1ac
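A minimal usage sketch (hypothetical kernel, modeled on the tests added in this commit; the name leader_kernel is illustrative) showing the new intrinsics electing one leader per 64-thread group:

import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def leader_kernel(num_threads: int = 128):

    @T.prim_func
    def kernel(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # Elect a single lane in every 64-thread group.
            elected = T.shuffle_elect(64)
            # Elected lanes record their warp index; all other threads write -1.
            A[tx] = T.if_then_else(elected, T.get_warp_idx_sync(), -1)

    return kernel

With the default warp size of 32, threads 0 and 64 are elected and write warp indices 0 and 2; every other thread writes -1.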
...@@ -218,6 +218,26 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
...
...@@ -358,6 +358,38 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Return the canonical lane index for the calling thread.
*
* get_lane_idx([warp_size])
*
*/
TVM_DLL const Op &get_lane_idx();
/*!
* \brief Return the canonical warp index, assuming converged threads.
*
* get_warp_idx_sync([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx_sync();
/*!
* \brief Return the canonical warp index without synchronizing the warp.
*
* get_warp_idx([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx();
/*!
* \brief Return the canonical warp group index for converged threads.
*
* get_warp_group_idx([warp_size, warps_per_group])
*
*/
TVM_DLL const Op &get_warp_group_idx();
/*!
* \brief Wait the previous wgmma to finish
*
...
...@@ -1968,6 +1968,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::get_lane_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
os << "tl::get_lane_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
os << "tl::get_warp_idx_sync(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx expects at most one argument <warp_size>.";
os << "tl::get_warp_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_group_idx())) {
ICHECK_LE(op->args.size(), 2)
<< "tl.get_warp_group_idx expects <warp_size, warps_per_group>.";
os << "tl::get_warp_group_idx(";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << PrintExpr(op->args[i]);
}
os << ")";
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) {
...
#pragma once
#include "common.h"
#include "cutlass/cutlass.h"
#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cutlass/cutlass.h"
#endif
namespace tl {
namespace detail {
// Provide architecture-specific defaults so callers may omit arguments.
TL_DEVICE constexpr int default_warp_size() {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
TL_DEVICE constexpr int default_warps_per_group() { return 4; }
TL_DEVICE int linear_thread_idx_in_block() {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
#else
return 0;
#endif
}
} // namespace detail
TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() % warp_size;
}
TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}
TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}
TL_DEVICE int
get_warp_group_idx(int warp_size = detail::default_warp_size(),
int warps_per_group = detail::default_warps_per_group()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
warps_per_group =
warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group();
int threads_per_group = warp_size * warps_per_group;
threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size;
return detail::linear_thread_idx_in_block() / threads_per_group;
}
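// Example mapping for a 1-D block of 256 threads using the defaults above
// (warp_size = 32, warps_per_group = 4): thread 130 gets
// get_lane_idx() == 130 % 32 == 2, get_warp_idx() == 130 / 32 == 4,
// and get_warp_group_idx() == 130 / (32 * 4) == 1.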
#if __CUDA_ARCH_LIST__ >= 900
TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
...@@ -61,5 +114,6 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
#endif
} // namespace tl
from typing import Optional
import tilelang.language as T
import tilelang.testing
import torch
from tilelang.utils.target import check_hip_availability
_IS_HIP_AVAILABLE = check_hip_availability()
_DEFAULT_WARPS_PER_GROUP = 4
def _resolve_warp_size(warp_size: Optional[int]) -> int:
if warp_size is not None:
return int(warp_size)
return 64 if _IS_HIP_AVAILABLE else 32
def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
if warps_per_group is not None:
return int(warps_per_group)
return _DEFAULT_WARPS_PER_GROUP
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_lane_idx(warp_size)
return laneid_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx_sync(warp_size)
return warp_idx_sync_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx(warp_size)
return warp_idx_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
@T.prim_func
def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)
return warp_group_idx_kernel
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
@T.prim_func
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
elected = T.shuffle_elect(thread_extent)
A[tx] = elected
return shuffle_elect_kernel
def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_laneid_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_group_idx(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
expected_warps_per_group = _resolve_warps_per_group(warps_per_group)
threads_per_group = expected_warp_size * expected_warps_per_group
if threads_per_group <= 0:
raise ValueError("threads_per_group must be positive.")
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
if thread_extent < 0:
raise ValueError("thread_extent must be non-negative.")
kernel = _shuffle_elect_kernel(num_threads, thread_extent)
A = kernel()
print(kernel.get_kernel_source())
print(A)
indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
if thread_extent == 0:
mask = indices == 0
elif thread_extent > 0:
mask = (indices % thread_extent) == 0
else:
mask = torch.zeros_like(indices, dtype=torch.bool)
ref = mask.to(dtype=A.dtype, device=A.device)
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
run_get_lane_id()
@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
run_get_lane_id(num_threads=256, warp_size=64)
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
run_get_warp_idx_sync()
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
run_get_warp_idx_sync(num_threads=256, warp_size=16)
@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
run_get_warp_idx()
@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
run_get_warp_idx(num_threads=320, warp_size=20)
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
run_get_warp_group_idx()
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)
@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
run_shuffle_elect(num_threads=256, thread_extent=64)
@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
run_shuffle_elect(num_threads=128, thread_extent=0)
if __name__ == "__main__":
tilelang.testing.main()
# run_get_lane_id()
...@@ -5,12 +5,26 @@ from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
from typing import Union, Any
from typing import Union, Any, Optional
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability()
def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]:
"""
Normalize warp sizing arguments so both Python ints and PrimExpr values
are accepted uniformly.
"""
if value is None:
return None
if isinstance(value, PrimExpr):
return value
if isinstance(value, int):
return tir.IntImm("int32", value)
raise TypeError(f"Expect warp sizing argument to be int or PrimExpr, but got {type(value)}.")
def create_list_of_mbarrier(*args: Any) -> Call:
"""
Create a list of memory barrier handles.
...@@ -280,6 +294,140 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.
Example
-------
>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64) # override warp size explicitly
Implementation Notes
--------------------
Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in
`src/tl_templates/cuda/intrin.h`, which computes the lane index from the
linear thread id using the provided `warp_size`.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size used for the index calculation.
Example
-------
>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)
Implementation Notes
--------------------
Emits `tl::get_warp_idx_sync(warp_size)` which divides the block-linear
thread id by `warp_size`, matching the semantics of CUTLASS' canonical helpers.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size used for the index calculation.
Example
-------
>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)
Implementation Notes
--------------------
Lowers to `tl::get_warp_idx(warp_size)` which divides the block-linear
thread id by the provided `warp_size` without requiring warp convergence.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"), warp_size_expr)
def get_warp_group_idx(
warp_size: Optional[Union[int, PrimExpr]] = None,
warps_per_group: Optional[Union[int, PrimExpr]] = None,
) -> PrimExpr:
"""Return the canonical warp group index for the calling thread.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
warps_per_group : Optional[Union[int, PrimExpr]]
Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.
Example
-------
>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6) # treat 6 warps as a group
Implementation Notes
--------------------
Generates `tl::get_warp_group_idx(warp_size, warps_per_group)` which
divides the block-linear thread id by `warp_size * warps_per_group`,
matching the canonical ordering while allowing architecture-specific overrides.
"""
warp_size_expr = _normalize_index_arg(warp_size)
warps_per_group_expr = _normalize_index_arg(warps_per_group)
args = []
if warp_size_expr is not None:
args.append(warp_size_expr)
if warps_per_group_expr is not None:
if warp_size_expr is None:
raise ValueError("get_warp_group_idx expects `warp_size` when specifying "
"`warps_per_group`.")
args.append(warps_per_group_expr)
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
def shuffle_elect(thread_extent: int) -> PrimExpr:
"""Elect exactly one lane within a logical thread group.
Parameters
----------
thread_extent : int
Size (in threads) of the group in which a single lane should be elected.
Passing 0 elects a single lane in the entire thread block.
Example
-------
>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))
Implementation Notes
--------------------
Lowered to the CUDA helper `tl::tl_shuffle_elect<thread_extent>()` defined in
`src/tl_templates/cuda/intrin.h`, which relies on
`cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or
`__shfl_sync`) to pick one lane per group.
"""
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
...