Unverified Commit b78d8404 authored by Lei Wang, committed by GitHub

[Language] Expose `T.get_warp_idx_sync` and `T.shuffle_elect` for efficient thread election (#989)



* Expose CUDA warp/lane intrinsics in TileLang frontend

* generalize warp indexing intrinsics and add coverage

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 32ddc1ac
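For orientation before the diff: a minimal sketch of how the newly exposed intrinsics are called from the TileLang frontend. It mirrors the unit tests added below; the kernel and buffer names are illustrative, not part of this change.

import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def warp_metadata_kernel(num_threads: int = 128):

    @T.prim_func
    def kernel(A: T.Tensor((num_threads, 4), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # Defaults: warp_size = 32 on NVIDIA / 64 on AMD, 4 warps per group.
            A[tx, 0] = T.get_lane_idx()
            A[tx, 1] = T.get_warp_idx_sync()
            A[tx, 2] = T.get_warp_group_idx()
            # 1 for one elected lane per 64-thread group, 0 elsewhere.
            A[tx, 3] = T.shuffle_elect(64)

    return kernel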
@@ -218,6 +218,26 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(wait_wgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
@@ -358,6 +358,38 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Return the canonical lane index for the calling thread.
*
* get_lane_idx([warp_size])
*
*/
TVM_DLL const Op &get_lane_idx();
/*!
* \brief Return the canonical warp index, assuming converged threads.
*
* get_warp_idx_sync([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx_sync();
/*!
* \brief Return the canonical warp index without synchronizing the warp.
*
* get_warp_idx([warp_size])
*
*/
TVM_DLL const Op &get_warp_idx();
/*!
* \brief Return the canonical warp group index for converged threads.
*
* get_warp_group_idx([warp_size, warps_per_group])
*
*/
TVM_DLL const Op &get_warp_group_idx();
/*!
* \brief Wait the previous wgmma to finish
*
......
@@ -1968,6 +1968,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::get_lane_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_lane_idx expects at most one argument <warp_size>.";
os << "tl::get_lane_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx_sync())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
os << "tl::get_warp_idx_sync(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_idx())) {
ICHECK_LE(op->args.size(), 1)
<< "tl.get_warp_idx expects at most one argument <warp_size>.";
os << "tl::get_warp_idx(";
if (!op->args.empty()) {
os << PrintExpr(op->args[0]);
}
os << ")";
} else if (op->op.same_as(tl::get_warp_group_idx())) {
ICHECK_LE(op->args.size(), 2)
<< "tl.get_warp_group_idx expects at most two arguments <warp_size, warps_per_group>.";
os << "tl::get_warp_group_idx(";
for (size_t i = 0; i < op->args.size(); ++i) {
if (i != 0) {
os << ", ";
}
os << PrintExpr(op->args[i]);
}
os << ")";
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) {
......
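A quick, hedged sanity check on the codegen branch above (reusing the JIT and test patterns from later in this diff; not part of the change itself): compile a trivial kernel and confirm the emitted CUDA calls the `tl::` helper.

import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def _probe_warp_idx_sync(num_threads: int = 128):

    @T.prim_func
    def probe(A: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            A[tx] = T.get_warp_idx_sync()

    return probe


kernel = _probe_warp_idx_sync()
# When the frontend call omits warp_size, the generated source should contain
# a zero-argument call to the device helper.
assert "tl::get_warp_idx_sync(" in kernel.get_kernel_source()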
#pragma once
#include "common.h"
#include "cutlass/cutlass.h"
#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/mma_sm90_gmma.hpp"
#include "cutlass/cutlass.h"
#endif
namespace tl {
namespace detail {
// Provide architecture-specific defaults so callers may omit arguments.
TL_DEVICE constexpr int default_warp_size() {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP_DEVICE_COMPILE__)
return 64;
#else
return 32;
#endif
}
TL_DEVICE constexpr int default_warps_per_group() { return 4; }
TL_DEVICE int linear_thread_idx_in_block() {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
#else
return 0;
#endif
}
} // namespace detail
TL_DEVICE int get_lane_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() % warp_size;
}
TL_DEVICE int get_warp_idx_sync(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}
TL_DEVICE int get_warp_idx(int warp_size = detail::default_warp_size()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
return detail::linear_thread_idx_in_block() / warp_size;
}
TL_DEVICE int
get_warp_group_idx(int warp_size = detail::default_warp_size(),
int warps_per_group = detail::default_warps_per_group()) {
warp_size = warp_size > 0 ? warp_size : detail::default_warp_size();
warps_per_group =
warps_per_group > 0 ? warps_per_group : detail::default_warps_per_group();
int threads_per_group = warp_size * warps_per_group;
threads_per_group = threads_per_group > 0 ? threads_per_group : warp_size;
return detail::linear_thread_idx_in_block() / threads_per_group;
}
#if __CUDA_ARCH_LIST__ >= 900
TL_DEVICE void warpgroup_arrive() { cute::warpgroup_arrive(); }
TL_DEVICE void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }
@@ -61,5 +114,6 @@ template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
#endif
} // namespace tl
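To make the index arithmetic in the helpers above concrete, a small host-side sketch (plain Python, not part of the diff) that reproduces what they compute under the NVIDIA defaults (warp_size = 32, warps_per_group = 4):

WARP_SIZE = 32          # detail::default_warp_size() on NVIDIA
WARPS_PER_GROUP = 4     # detail::default_warps_per_group()

for tid in (0, 31, 32, 127, 128):
    lane = tid % WARP_SIZE                        # tl::get_lane_idx
    warp = tid // WARP_SIZE                       # tl::get_warp_idx(_sync)
    group = tid // (WARP_SIZE * WARPS_PER_GROUP)  # tl::get_warp_group_idx
    print(f"tid={tid:3d} lane={lane:2d} warp={warp} group={group}")
# tid=  0 lane= 0 warp=0 group=0
# tid= 31 lane=31 warp=0 group=0
# tid= 32 lane= 0 warp=1 group=0
# tid=127 lane=31 warp=3 group=0
# tid=128 lane= 0 warp=4 group=1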
from typing import Optional
import tilelang.language as T
import tilelang.testing
import torch
from tilelang.utils.target import check_hip_availability
_IS_HIP_AVAILABLE = check_hip_availability()
_DEFAULT_WARPS_PER_GROUP = 4
def _resolve_warp_size(warp_size: Optional[int]) -> int:
if warp_size is not None:
return int(warp_size)
return 64 if _IS_HIP_AVAILABLE else 32
def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
if warps_per_group is not None:
return int(warps_per_group)
return _DEFAULT_WARPS_PER_GROUP
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_lane_idx(warp_size)
return laneid_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx_sync(warp_size)
return warp_idx_sync_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx(warp_size)
return warp_idx_kernel
@tilelang.jit(out_idx=[-1])
def _get_warp_group_idx_kernel(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
@T.prim_func
def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)
return warp_group_idx_kernel
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
@T.prim_func
def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
elected = T.shuffle_elect(thread_extent)
A[tx] = elected
return shuffle_elect_kernel
def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_laneid_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) % expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_idx_sync(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_sync_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
kernel = _get_warp_idx_kernel(num_threads, warp_size)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // expected_warp_size
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_get_warp_group_idx(
num_threads: int = 128,
warp_size: Optional[int] = None,
warps_per_group: Optional[int] = None,
):
kernel = _get_warp_group_idx_kernel(num_threads, warp_size, warps_per_group)
A = kernel()
print(kernel.get_kernel_source())
print(A)
expected_warp_size = _resolve_warp_size(warp_size)
expected_warps_per_group = _resolve_warps_per_group(warps_per_group)
threads_per_group = expected_warp_size * expected_warps_per_group
if threads_per_group <= 0:
raise ValueError("threads_per_group must be positive.")
ref = torch.arange(num_threads, dtype=A.dtype, device=A.device) // threads_per_group
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
def run_shuffle_elect(num_threads: int = 128, thread_extent: int = 64):
if thread_extent < 0:
raise ValueError("thread_extent must be non-negative.")
kernel = _shuffle_elect_kernel(num_threads, thread_extent)
A = kernel()
print(kernel.get_kernel_source())
print(A)
indices = torch.arange(num_threads, device=A.device, dtype=torch.int64)
if thread_extent == 0:
mask = indices == 0
elif thread_extent > 0:
mask = (indices % thread_extent) == 0
else:
mask = torch.zeros_like(indices, dtype=torch.bool)
ref = mask.to(dtype=A.dtype, device=A.device)
torch.testing.assert_close(A.cpu(), ref.cpu())
return A
@tilelang.testing.requires_cuda
def test_get_lane_idx_default():
run_get_lane_id()
@tilelang.testing.requires_cuda
def test_get_lane_idx_custom():
run_get_lane_id(num_threads=256, warp_size=64)
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_default():
run_get_warp_idx_sync()
@tilelang.testing.requires_cuda
def test_get_warp_idx_sync_custom():
run_get_warp_idx_sync(num_threads=256, warp_size=16)
@tilelang.testing.requires_cuda
def test_get_warp_idx_default():
run_get_warp_idx()
@tilelang.testing.requires_cuda
def test_get_warp_idx_custom():
run_get_warp_idx(num_threads=320, warp_size=20)
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_default():
run_get_warp_group_idx()
@tilelang.testing.requires_cuda
def test_get_warp_group_idx_custom():
run_get_warp_group_idx(num_threads=512, warp_size=32, warps_per_group=5)
@tilelang.testing.requires_cuda
def test_shuffle_elect_default():
run_shuffle_elect(num_threads=256, thread_extent=64)
@tilelang.testing.requires_cuda
def test_shuffle_elect_block_leader():
run_shuffle_elect(num_threads=128, thread_extent=0)
if __name__ == "__main__":
tilelang.testing.main()
# run_get_lane_id()
@@ -5,12 +5,26 @@ from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
from typing import Union, Any
from typing import Union, Any, Optional
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability()
def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]:
"""
Normalize warp sizing arguments so both Python ints and PrimExpr values
are accepted uniformly.
"""
if value is None:
return None
if isinstance(value, PrimExpr):
return value
if isinstance(value, int):
return tir.IntImm("int32", value)
raise TypeError(f"Expect warp sizing argument to be int or PrimExpr, but got {type(value)}.")
def create_list_of_mbarrier(*args: Any) -> Call:
"""
Create a list of memory barrier handles.
@@ -280,6 +294,140 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.
Example
-------
>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64) # override warp size explicitly
Implementation Notes
--------------------
Lowers to the CUDA helper `tl::get_lane_idx(warp_size)` defined in
`src/tl_templates/cuda/intrin.h`, which computes the lane index from the
linear thread id using the provided `warp_size`.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size used for the index calculation.
Example
-------
>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)
Implementation Notes
--------------------
Emits `tl::get_warp_idx_sync(warp_size)` which divides the block-linear
thread id by `warp_size`, matching the semantics of CUTLASS' canonical helpers.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size used for the index calculation.
Example
-------
>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)
Implementation Notes
--------------------
Lowers to `tl::get_warp_idx(warp_size)` which divides the block-linear
thread id by the provided `warp_size` without requiring warp convergence.
"""
warp_size_expr = _normalize_index_arg(warp_size)
if warp_size_expr is None:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"))
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx"), warp_size_expr)
def get_warp_group_idx(
warp_size: Optional[Union[int, PrimExpr]] = None,
warps_per_group: Optional[Union[int, PrimExpr]] = None,
) -> PrimExpr:
"""Return the canonical warp group index for the calling thread.
Parameters
----------
warp_size : Optional[Union[int, PrimExpr]]
Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
warps_per_group : Optional[Union[int, PrimExpr]]
Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.
Example
-------
>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6) # treat 6 warps as a group
Implementation Notes
--------------------
Generates `tl::get_warp_group_idx(warp_size, warps_per_group)` which
divides the block-linear thread id by `warp_size * warps_per_group`,
matching the canonical ordering while allowing architecture-specific overrides.
"""
warp_size_expr = _normalize_index_arg(warp_size)
warps_per_group_expr = _normalize_index_arg(warps_per_group)
args = []
if warp_size_expr is not None:
args.append(warp_size_expr)
if warps_per_group_expr is not None:
if warp_size_expr is None:
raise ValueError("get_warp_group_idx expects `warp_size` when specifying "
"`warps_per_group`.")
args.append(warps_per_group_expr)
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_group_idx"), *args)
def shuffle_elect(thread_extent: int) -> PrimExpr:
"""Elect exactly one lane within a logical thread group.
Parameters
----------
thread_extent : int
Size (in threads) of the group in which a single lane should be elected.
Passing 0 elects a single lane in the entire thread block.
Example
-------
>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))
Implementation Notes
--------------------
Lowered to the CUDA helper `tl::tl_shuffle_elect<thread_extent>()` defined in
`src/tl_templates/cuda/intrin.h`, which relies on
`cutlass::canonical_warp_idx_sync()` and `cute::elect_one_sync()` (or
`__shfl_sync`) to pick one lane per group.
"""
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def wait_wgmma(id: int):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
......
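Putting the frontend pieces together, a hedged sketch of the election pattern this PR targets: one lane per 128-thread group records its warp-group index while the rest write a sentinel. Kernel and buffer names are illustrative; per the unit test above, the elected lane is the first thread of each group.

import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])
def group_leader_kernel(num_threads: int = 256):

    @T.prim_func
    def kernel(Out: T.Tensor((num_threads,), "int32")):
        with T.Kernel(1, threads=num_threads) as _:
            tx = T.get_thread_binding()
            # The elected lane of each 128-thread group stores its warp-group
            # index; every other lane stores -1.
            Out[tx] = T.if_then_else(
                T.shuffle_elect(128), T.get_warp_group_idx(), -1)

    return kernel


# Expected with the defaults (256 threads): Out[0] == 0, Out[128] == 1, and
# all other entries == -1.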