Unverified Commit caa6dd3f authored by Tong WU, committed by GitHub

[Feat] Support warp reduce (#1316)

* [Feat] Support warp reduce

* lint

* add test

* lint
parent 6c2162a9
@@ -341,5 +341,30 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
@@ -571,6 +571,31 @@ TVM_DLL const Op &device_assert();
 */
TVM_DLL const Op &device_assert_with_msg();

/*!
 * \brief tilelang intrinsic for warp-level sum reduction.
 */
TVM_DLL const Op &warp_reduce_sum();

/*!
 * \brief tilelang intrinsic for warp-level max reduction.
 */
TVM_DLL const Op &warp_reduce_max();

/*!
 * \brief tilelang intrinsic for warp-level min reduction.
 */
TVM_DLL const Op &warp_reduce_min();

/*!
 * \brief tilelang intrinsic for warp-level bitwise-and reduction.
 */
TVM_DLL const Op &warp_reduce_bitand();

/*!
 * \brief tilelang intrinsic for warp-level bitwise-or reduction.
 */
TVM_DLL const Op &warp_reduce_bitor();

} // namespace tl
} // namespace tvm
@@ -2609,6 +2609,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
    std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(tl::warp_reduce_sum())) {
    os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::warp_reduce_max())) {
    os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::warp_reduce_min())) {
    os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::warp_reduce_bitand())) {
    os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::warp_reduce_bitor())) {
    os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
@@ -250,4 +250,35 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
  }
};

// Warp-wide reduction using butterfly (XOR) shuffles: each step combines
// lanes whose ids differ in one bit, so after log2(32) = 5 steps every lane
// in the warp holds the fully reduced value.
template <typename T, typename ReduceOp>
TL_DEVICE T warp_reduce(T value, ReduceOp op) {
  constexpr uint32_t mask = 0xffffffff; // all 32 lanes participate
  value = op(value, __shfl_xor_sync(mask, value, 16));
  value = op(value, __shfl_xor_sync(mask, value, 8));
  value = op(value, __shfl_xor_sync(mask, value, 4));
  value = op(value, __shfl_xor_sync(mask, value, 2));
  value = op(value, __shfl_xor_sync(mask, value, 1));
  return value;
}

template <typename T> TL_DEVICE T warp_reduce_sum(T value) {
  return warp_reduce<T>(value, SumOp());
}

template <typename T> TL_DEVICE T warp_reduce_max(T value) {
  return warp_reduce<T>(value, MaxOp());
}

template <typename T> TL_DEVICE T warp_reduce_min(T value) {
  return warp_reduce<T>(value, MinOp());
}

template <typename T> TL_DEVICE T warp_reduce_bitand(T value) {
  return warp_reduce<T>(value, BitAndOp());
}

template <typename T> TL_DEVICE T warp_reduce_bitor(T value) {
  return warp_reduce<T>(value, BitOrOp());
}

} // namespace tl
import torch

import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):
    assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]

    @T.prim_func
    def main(x: T.Tensor((32), dtype)):
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding(0)
            local_val = T.alloc_local([1], dtype)
            local_val[0] = x[tx]
            reduced_val = T.alloc_local([1], dtype)
            if reduce_op == "sum":
                reduced_val[0] = T.warp_reduce_sum(local_val[0])
            elif reduce_op == "max":
                reduced_val[0] = T.warp_reduce_max(local_val[0])
            elif reduce_op == "min":
                reduced_val[0] = T.warp_reduce_min(local_val[0])
            elif reduce_op == "bitand":
                reduced_val[0] = T.warp_reduce_bitand(local_val[0])
            elif reduce_op == "bitor":
                reduced_val[0] = T.warp_reduce_bitor(local_val[0])
            x[tx] = reduced_val[0]

    return main


def test_warp_reduce_sum():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel('sum', 'float32')
    ref = torch.full_like(a, a.sum())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_max():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel("max", 'float32')
    print(kernel.get_kernel_source())
    ref = torch.full_like(a, a.max())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_min():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel("min", 'float32')
    ref = torch.full_like(a, a.min())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_bitand():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
    kernel = get_kernel("bitand", 'int32')
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val & a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_bitor():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
    kernel = get_kernel("bitor", 'int32')
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val | a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -65,6 +65,11 @@ from .reduce import (
    reduce_bitxor,  # noqa: F401
    cumsum,  # noqa: F401
    finalize_reducer,  # noqa: F401
    warp_reduce_sum,  # noqa: F401
    warp_reduce_max,  # noqa: F401
    warp_reduce_min,  # noqa: F401
    warp_reduce_bitand,  # noqa: F401
    warp_reduce_bitor,  # noqa: F401
)
from .print import print, device_assert # noqa: F401
from .customize import (
@@ -325,3 +325,83 @@ def finalize_reducer(reducer: tir.Buffer):
        tir.op.Op.get("tl.finalize_reducer"),
        reducer.access_ptr("w"),
    )


def warp_reduce_sum(value: tir.PrimExpr):
    """Perform a warp-level sum reduction on a register value.

    This function reduces a value across all threads in a warp using shuffle
    operations. Each thread provides a register `value`, and after the
    reduction every thread in the warp holds the sum of all values.

    Args:
        value (tir.PrimExpr): The input register value to reduce.

    Returns:
        tir.PrimExpr: The reduced sum value (same on all threads in the warp).
    """
    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_sum"), value)


def warp_reduce_max(value: tir.PrimExpr):
    """Perform a warp-level max reduction on a register value.

    This function reduces a value across all threads in a warp using shuffle
    operations. Each thread provides a register `value`, and after the
    reduction every thread in the warp holds the max of all values.

    Args:
        value (tir.PrimExpr): The input register value to reduce.

    Returns:
        tir.PrimExpr: The reduced max value (same on all threads in the warp).
    """
    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_max"), value)


def warp_reduce_min(value: tir.PrimExpr):
    """Perform a warp-level min reduction on a register value.

    This function reduces a value across all threads in a warp using shuffle
    operations. Each thread provides a register `value`, and after the
    reduction every thread in the warp holds the min of all values.

    Args:
        value (tir.PrimExpr): The input register value to reduce.

    Returns:
        tir.PrimExpr: The reduced min value (same on all threads in the warp).
    """
    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_min"), value)


def warp_reduce_bitand(value: tir.PrimExpr):
    """Perform a warp-level bitwise-and reduction on a register value.

    This function reduces a value across all threads in a warp using shuffle
    operations. Each thread provides a register `value`, and after the
    reduction every thread in the warp holds the bitwise-and of all values.

    Args:
        value (tir.PrimExpr): The input register value to reduce.

    Returns:
        tir.PrimExpr: The reduced bitwise-and value (same on all threads in the warp).
    """
    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitand"), value)


def warp_reduce_bitor(value: tir.PrimExpr):
    """Perform a warp-level bitwise-or reduction on a register value.

    This function reduces a value across all threads in a warp using shuffle
    operations. Each thread provides a register `value`, and after the
    reduction every thread in the warp holds the bitwise-or of all values.

    Args:
        value (tir.PrimExpr): The input register value to reduce.

    Returns:
        tir.PrimExpr: The reduced bitwise-or value (same on all threads in the warp).
    """
    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitor"), value)
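
For context, here is a minimal usage sketch (not part of this commit) showing how the intrinsics above could be used for per-warp partial sums in a block with more than one warp. The kernel name, buffer shapes, and launch configuration are hypothetical; the only assumption beyond this diff is that the underlying shuffles are warp-scoped, so each warp reduces independently.

import tilelang
import tilelang.language as T


@tilelang.jit
def per_warp_sum(dtype: str = "float32"):
    # Hypothetical sketch: 128 threads = 4 warps; each warp sums its own
    # 32-element row, and lane 0 of every warp writes the warp-wide result.
    @T.prim_func
    def main(x: T.Tensor((4, 32), dtype), y: T.Tensor((4,), dtype)):
        with T.Kernel(1, threads=128):
            tx = T.get_thread_binding(0)
            warp_id = tx // 32
            lane_id = tx % 32
            val = T.alloc_local([1], dtype)
            val[0] = x[warp_id, lane_id]
            val[0] = T.warp_reduce_sum(val[0])  # every lane now holds its warp's sum
            if lane_id == 0:
                y[warp_id] = val[0]

    return main


# Expected behavior, assuming the sketch compiles as written:
#   x = torch.randn(4, 32, device="cuda")
#   y = torch.zeros(4, device="cuda")
#   per_warp_sum()(x, y)   # y[i] ~= x[i].sum() for each warp i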