"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "04b3214a11ffa7d6e3b8b66a284ef6541bb80a65"
Unverified Commit ba410ae3 authored by Zhengju Tang, committed by GitHub

[Feature] Support Reduce operators for bitwise and/or/xor (#1074)

* [Feature] Support Reduce operators for bitwise and/or/xor

* [Lint]
parent 1516f43c
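The new reductions have simple fold semantics: bitand folds with & from an all-ones identity, while bitor and bitxor fold with | and ^ from zero. A minimal plain-Python reference sketch of these semantics (illustrative only, not tilelang code):

from functools import reduce
import operator

# Identity elements: all-ones for AND (x & ~0 == x), zero for OR/XOR.
row = [0b1100, 0b1010, 0b1001]
assert reduce(operator.and_, row, -1) == 0b1000
assert reduce(operator.or_, row, 0) == 0b1111
assert reduce(operator.xor, row, 0) == 0b1111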
......@@ -443,7 +443,8 @@ class _attention(torch.autograd.Function):
dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
kernel(q, k, v, do, lse, delta, dq, dk, dv)
dq, _, _ = mod_post(dq, torch.zeros_like(k, dtype=torch.float32),
torch.zeros_like(v, dtype=torch.float32))
dk, dv = dk.sum(0), dv.sum(0)
return dq, dk, dv, None, None, None
......
......@@ -70,6 +70,19 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
}
} else if (type->isAbsMax()) {
return make_const(dst->dtype, 0);
} else if (type->isBitAnd()) {
if (is_int) {
return make_const(dst->dtype, -1);
} else if (is_uint) {
// All-ones mask for the unsigned width, computed in 64-bit to avoid
// undefined behavior when bits == 32 or 64.
return make_const(dst->dtype, ~uint64_t(0) >> (64 - bits));
} else {
// Bitwise reductions are only defined for integer types; this branch
// should be unreachable.
return make_const(dst->dtype, -INFINITY);
}
} else if (type->isBitOr()) {
return make_zero(dst->dtype);
} else if (type->isBitXor()) {
return make_zero(dst->dtype);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
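The init values above are the operators' identity elements. For signed integers, two's-complement -1 is already all ones, so the same constant serves every signed width; unsigned widths get an explicit all-ones mask instead. A quick illustrative Python check of that equivalence:

# -1 in two's complement equals the all-ones mask of any width.
for bits in (8, 16, 32):
    mask = (1 << bits) - 1
    assert (-1 & mask) == mask  # signed all-ones == unsigned all-ones
    x = 0xAB & mask
    assert (x & mask) == x  # AND identity leaves x unchanged
    assert (x | 0) == x and (x ^ 0) == x  # OR/XOR identity is zero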
......@@ -91,6 +104,12 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
return Min(lhs, rhs);
} else if (type->isAbsMax()) {
return Max(Max(lhs, rhs), -Min(lhs, rhs));
} else if (type->isBitAnd()) {
return lhs & rhs;
} else if (type->isBitOr()) {
return lhs | rhs;
} else if (type->isBitXor()) {
return lhs ^ rhs;
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
}
......@@ -107,6 +126,12 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
return "tl::MinOp";
} else if (type->isAbsMax()) {
return "tl::MaxOp";
} else if (type->isBitAnd()) {
return "tl::BitAndOp";
} else if (type->isBitOr()) {
return "tl::BitOrOp";
} else if (type->isBitXor()) {
return "tl::BitXorOp";
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return "";
......@@ -195,6 +220,12 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
require_init = true;
} else if (this->type->isAbsSum()) {
require_init = true;
} else if (this->type->isBitAnd()) {
require_init = true;
} else if (this->type->isBitOr()) {
require_init = true;
} else if (this->type->isBitXor()) {
require_init = true;
}
Buffer clear_buffer = dst_buffer;
......@@ -203,6 +234,12 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
need_duplicate = true;
} else if (this->type->isAbsSum() && !this->clear) {
need_duplicate = true;
} else if (this->type->isBitAnd()) {
// BitAnd always reduces into a duplicate buffer: the all-ones init value
// must not clobber dst when clear is false, and the clear == true path
// below copies the duplicate back.
need_duplicate = true;
} else if (this->type->isBitOr() && !this->clear) {
need_duplicate = true;
} else if (this->type->isBitXor() && !this->clear) {
need_duplicate = true;
}
if (need_duplicate) {
......@@ -213,9 +250,10 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
// make reduce-init stmt
if (require_init) {
stmts.push_back(
BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
}
// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
......@@ -298,6 +336,29 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Add(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitAnd()) {
if (!this->clear) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_and(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
stmts.push_back(BufferStore(
dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
}
} else if (this->type->isBitOr()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_or(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else if (this->type->isBitXor()) {
stmts.push_back(
BufferStore(dst_buffer,
bitwise_xor(BufferLoad(dst_buffer, dst_indices),
BufferLoad(clear_buffer, dst_indices)),
dst_indices));
} else {
ICHECK(false) << "Unsupported reduce type: " << this->type->type;
}
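The asymmetry in the duplicate-buffer handling above appears deliberate: AND and OR are idempotent, so combining a buffer with itself is harmless, but XOR is not, since x ^ x == 0 would wipe the result. That is presumably why BitOr and BitXor only duplicate when clear is false, while BitAnd always duplicates and its clear == true path copies instead of re-combining. In Python terms:

x = 0b1011
assert (x & x) == x and (x | x) == x  # AND/OR are idempotent
assert (x ^ x) == 0  # XOR with itself cancels; re-combining would zero the result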
......
......@@ -21,6 +21,9 @@ enum class ReduceTypeEnum : uint8_t {
kMax, ///< Maximum value reduction
kMin, ///< Minimum value reduction
kAbsMax, ///< Maximum absolute value reduction
kBitAnd, ///< Bitwise AND reduction
kBitOr, ///< Bitwise OR reduction
kBitXor, ///< Bitwise XOR reduction
};
/// Node class representing a reduction type
......@@ -50,6 +53,9 @@ public:
bool isMax() const { return type == int(ReduceTypeEnum::kMax); }
bool isMin() const { return type == int(ReduceTypeEnum::kMin); }
bool isAbsMax() const { return type == int(ReduceTypeEnum::kAbsMax); }
bool isBitAnd() const { return type == int(ReduceTypeEnum::kBitAnd); }
bool isBitOr() const { return type == int(ReduceTypeEnum::kBitOr); }
bool isBitXor() const { return type == int(ReduceTypeEnum::kBitXor); }
};
/// Wrapper class for reduction type with string-based construction
......@@ -68,6 +74,12 @@ public:
node->type = int(ReduceTypeEnum::kAbsMax);
} else if (type == "min") {
node->type = int(ReduceTypeEnum::kMin);
} else if (type == "bitand") {
node->type = int(ReduceTypeEnum::kBitAnd);
} else if (type == "bitor") {
node->type = int(ReduceTypeEnum::kBitOr);
} else if (type == "bitxor") {
node->type = int(ReduceTypeEnum::kBitXor);
} else {
LOG(FATAL) << "Invalid reduce type: " << type;
}
......
......@@ -22,6 +22,24 @@ struct MinOp {
}
};
struct BitAndOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x & y;
}
};
struct BitOrOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x | y;
}
};
struct BitXorOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x ^ y;
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0,
int all_threads = threads>
struct AllReduce {
......
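These functors supply the combining step that AllReduce applies across threads, typically as a butterfly/tree reduction. A rough Python sketch of that pattern over a power-of-two group, with combine standing in for the functor's operator() (illustrative only, not the device code):

import operator

def butterfly_all_reduce(vals, combine):
    # One value per "thread"; length must be a power of two.
    n = len(vals)
    offset = n // 2
    while offset > 0:
        # Each lane combines with the lane `offset` away, mirroring a shuffle-xor step.
        vals = [combine(vals[i], vals[i ^ offset]) for i in range(n)]
        offset //= 2
    return vals  # every lane now holds the full reduction

vals = [0b0001, 0b0010, 0b0100, 0b1000]
assert butterfly_all_reduce(vals, operator.or_) == [0b1111] * 4
assert butterfly_all_reduce(vals, operator.xor) == [0b1111] * 4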
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
},
)
def bitwise_reduce(
M,
N,
block_M,
block_N,
name,
func,
clear=True,
):
@T.prim_func
def reduce_func(
A: T.Tensor((M, N), "int32"),
B: T.Tensor((M,), "int32"),
Output: T.Tensor((M,), "int32"),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32")
A_fragment = T.alloc_fragment((block_M, block_N), "int32")
B_shared = T.alloc_shared((block_M,), "int32")
B_fragment = T.alloc_fragment((block_M), "int32")
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.copy(B[by * block_M], B_shared)
T.copy(B_shared, B_fragment)
func(A_fragment, B_fragment, clear=clear)
T.copy(B_fragment, Output[by * block_M])
return reduce_func
def run_single_bitwise_reduce(
name,
func,
clear=True,
):
M, N = 32, 32
block_M, block_N = 32, 32
kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear)
# Generate test data that exercises all bit patterns for robust bitwise reduce testing
a = torch.zeros((M, N), device="cuda", dtype=torch.int32)
# Fill with patterns that will produce meaningful results for bitwise operations:
# - Different bit patterns across rows/columns
# - Mix of 0s and 1s in various positions
# - Some all-1s and all-0s patterns for edge cases
for i in range(M):
for j in range(N):
# Create varied bit patterns:
# Row-based pattern: alternating bits based on row index
row_pattern = (i & 0xF) << (i % 4) # 4-bit patterns shifted by row
# Column-based pattern: different bit positions set based on column
col_pattern = (1 << (j % 31)) # Single bit set at different positions
# Combine patterns with XOR to create diverse bit distributions
# Add some deterministic "noise" based on position
position_factor = (i * N + j) % 256
# Final value combines all patterns
a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF
if i % 4 == 0:
a[i, j] &= ~(0x1 << (i // 4))
elif i % 2 == 0:
a[i, j] |= (0x1 << (i // 2))
if name == "reduce_bitand":
expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
elif name == "reduce_bitor" or name == "reduce_bitxor":
expected = torch.full((M,), 0, device="cuda", dtype=torch.int32)
else:
raise ValueError("Invalid name: {}".format(name))
output = kernel(a, expected)
# Fold the same initial values over each row on the host to build the reference result.
for i in range(M):
for j in range(N):
if name == "reduce_bitand":
expected[i] = expected[i] & a[i, j]
elif name == "reduce_bitor":
expected[i] = expected[i] | a[i, j]
elif name == "reduce_bitxor":
expected[i] = expected[i] ^ a[i, j]
else:
raise ValueError("Invalid name: {}".format(name))
assert torch.all(output == expected)
print("✓ {} with clear={} test passed".format(name, clear))
@tilelang.testing.requires_cuda
def test_bitwise_reduce_ops():
run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True)
run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True)
run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True)
run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False)
run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False)
run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -58,6 +58,9 @@ from .reduce import (
reduce_sum, # noqa: F401
reduce_abssum, # noqa: F401
reduce_absmax, # noqa: F401
reduce_bitand, # noqa: F401
reduce_bitor, # noqa: F401
reduce_bitxor, # noqa: F401
cumsum, # noqa: F401
finalize_reducer, # noqa: F401
)
......
......@@ -139,6 +139,51 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: boo
return reduce(buffer, out, "absmax", dim, clear)
def reduce_bitand(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-and on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitand", dim, clear)
def reduce_bitor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-or on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitor", dim, clear)
def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True):
"""Perform reduce bitwise-xor on input buffer, store the result to output buffer.
Args:
buffer (tir.Buffer): The input buffer
out (tir.Buffer): The output buffer
dim (int): The dimension to perform reduce on
Returns:
tir.Call: Handle to the reduction operation
"""
dim = _legalize_dim(buffer, dim)
return reduce(buffer, out, "bitxor", dim, clear)
@macro
def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
......