Unverified Commit 747381ae authored by Lei Wang, committed by GitHub

[TileOp] Implement `CumSum1D` (#978)

* Support cumsum over 1D buffers
parent 0ae183db
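In short, `T.cumsum` now handles 1D shared-memory and fragment buffers. A minimal sketch of the new path, mirroring the tests added in this commit (the `tl.compile`/`out_idx` usage follows the existing test harness; treat this as an illustration, not part of the diff):

    import torch
    import tilelang as tl
    import tilelang.language as T

    N, block_N = 1024, 128

    @T.prim_func
    def cumsum_1d(A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32")):
        # Each block loads a block_N-element tile into shared memory,
        # scans it in place, and writes the tile back out, so the result
        # is a per-tile inclusive cumsum.
        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
            A_shared = T.alloc_shared((block_N,), "float32")
            T.copy(A[bx * block_N], A_shared)
            T.cumsum(src=A_shared, dim=0)
            T.copy(A_shared, B[bx * block_N])

    jit_kernel = tl.compile(cumsum_1d, out_idx=-1)
    B = jit_kernel(torch.randn(N, device="cuda"))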
@@ -420,12 +420,23 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
   std::stringstream ss;
   auto threads = T.thread_bounds->extent;
-  ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
-     << (reverse ? "true" : "false") << ">::run";
-  Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
-                          dst.access_ptr(3)};
-  for (int i = 0; i < src->shape.size(); i++) {
-    args.push_back(src->shape[i]);
+  Array<PrimExpr> args;
+  int ndim = static_cast<int>(src->shape.size());
+  if (ndim == 1) {
+    ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
+                         "= 0.";
+    ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
+       << ">::run";
+    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
+            src->shape[0]};
+  } else if (ndim == 2) {
+    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
+       << (reverse ? "true" : "false") << ">::run";
+    args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
+            src->shape[0], src->shape[1]};
+  } else {
+    LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
+               << ndim << "D.";
   }
   return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
 } else {
@@ -446,4 +457,4 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 } // namespace tl
-} // namespace tvm
\ No newline at end of file
+} // namespace tvm
@@ -68,6 +68,74 @@ struct AllReduce {
   }
 };
 
+template <int threads, bool reverse = false> struct CumSum1D {
+  static_assert(threads == 1024 or threads == 512 or threads == 256 or
+                threads == 128 or threads == 64 or threads == 32);
+
+  template <typename T, int SEG = 32>
+  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
+                            int N) {
+    if (N <= 0)
+      return;
+    constexpr unsigned MASK = 0xffffffff;
+    const int tid = threadIdx.x;
+    const int lane = tid % SEG;
+    if (tid >= SEG)
+      return;
+    T carry = (T)0;
+    if (reverse) {
+      const int num_segments = (N + SEG - 1) / SEG;
+      for (int seg = num_segments - 1; seg >= 0; --seg) {
+        const int idx = seg * SEG + lane;
+        T val = (idx < N) ? src[idx] : (T)0;
+#pragma unroll
+        for (int off = 1; off < SEG; off <<= 1) {
+          T n = (T)__shfl_down_sync(MASK, val, off);
+          if (lane < SEG - off)
+            val += n;
+        }
+        val += carry;
+        if (idx < N)
+          dst[idx] = val;
+        T segSum = (T)__shfl_sync(MASK, val, 0);
+        if (lane == 0)
+          carry = segSum;
+        carry = (T)__shfl_sync(MASK, carry, 0);
+      }
+    } else {
+      const int num_segments = (N + SEG - 1) / SEG;
+      for (int seg = 0; seg < num_segments; ++seg) {
+        const int idx = seg * SEG + lane;
+        T val = (idx < N) ? src[idx] : (T)0;
+#pragma unroll
+        for (int off = 1; off < SEG; off <<= 1) {
+          T n = (T)__shfl_up_sync(MASK, val, off);
+          if (lane >= off)
+            val += n;
+        }
+        val += carry;
+        if (idx < N)
+          dst[idx] = val;
+        T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
+        if (lane == SEG - 1)
+          carry = segSum;
+        carry = (T)__shfl_sync(MASK, carry, SEG - 1);
+      }
+    }
+  }
+};
+
 template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
   static_assert(threads == 1024 or threads == 512 or threads == 256 or
                 threads == 128 or threads == 64 or threads == 32);
...
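A note on the added kernel: `CumSum1D` uses only the first warp. The buffer is swept in SEG = 32-element segments; within a segment each lane runs a Hillis-Steele inclusive scan in log2(SEG) = 5 shuffle steps (`__shfl_up_sync` forward, `__shfl_down_sync` when reversed), then the segment total is broadcast from the boundary lane and carried into the next segment. A rough pure-Python model of the forward path (illustrative only; a list stands in for the lane registers):

    def cumsum_1d_model(src, seg=32):
        # Models tl::CumSum1D's forward path: per-segment inclusive scan
        # plus a carry that threads segment totals together.
        n = len(src)
        dst = [0] * n
        carry = 0
        for seg_start in range(0, n, seg):
            # Lane registers; lanes past the end of the buffer hold 0.
            val = [src[seg_start + lane] if seg_start + lane < n else 0
                   for lane in range(seg)]
            off = 1
            while off < seg:
                # Hillis-Steele step: each lane adds the value held by the
                # lane 'off' positions below it (the __shfl_up_sync role).
                val = [val[lane] + (val[lane - off] if lane >= off else 0)
                       for lane in range(seg)]
                off <<= 1
            # Fold in the carry from earlier segments and write out.
            val = [v + carry for v in val]
            for lane in range(seg):
                if seg_start + lane < n:
                    dst[seg_start + lane] = val[lane]
            carry = val[seg - 1]  # last lane holds old carry + segment sum
        return dst

    assert cumsum_1d_model([1, 2, 3, 4, 5], seg=2) == [1, 3, 6, 10, 15]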
@@ -71,6 +71,75 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", scope="smem"):
     torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
 
+
+def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            A_shared = T.alloc_shared((block_N,), dtype)
+            T.copy(A[bx * block_N], A_shared)
+            T.cumsum(src=A_shared, dim=0, reverse=reverse)
+            T.copy(A_shared, B[bx * block_N])
+
+    return cumsum
+
+
+def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
+    import tilelang.language as T
+
+    @T.prim_func
+    def cumsum(
+            A: T.Tensor((N,), dtype),
+            B: T.Tensor((N,), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
+            A_shared = T.alloc_shared((block_N,), dtype)
+            A_fragment = T.alloc_fragment((block_N,), dtype)
+            T.copy(A[bx * block_N], A_shared)
+            T.copy(A_shared, A_fragment)
+            T.cumsum(src=A_fragment, dim=0, reverse=reverse)
+            T.copy(A_fragment, B[bx * block_N])
+
+    return cumsum
+
+
+def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"):
+    if scope == "smem":
+        program = cumsum_smem_test_1d(N, block_N, reverse, dtype)
+    elif scope == "fragment":
+        program = cumsum_fragment_test_1d(N, block_N, reverse, dtype)
+    else:
+        raise ValueError(f"Unknown scope {scope}")
+
+    jit_kernel = tl.compile(program, out_idx=-1)
+    A = torch.randn(N, dtype=getattr(torch, dtype)).cuda()
+
+    def ref_program(A):
+        ref_b = torch.empty_like(A)
+        num_blocks = (N + block_N - 1) // block_N
+        for j in range(num_blocks):
+            start = j * block_N
+            end = min(start + block_N, N)
+            chunk = A[start:end]
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            chunk = chunk.cumsum(dim=0)
+            if reverse:
+                chunk = torch.flip(chunk, dims=[0])
+            ref_b[start:end] = chunk
+        return ref_b
+
+    tilelang_res = jit_kernel(A)
+    ref_res = ref_program(A)
+    torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
+
+
 def test_cumsum_smem():
     # Test different sizes
     run_cumsum(1024, 1024, 128, 128)
@@ -92,5 +161,15 @@ def test_cumsum_fragment():
     run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
 
+
+def test_cumsum_smem_1d():
+    run_cumsum_1d(1024, 128)
+    run_cumsum_1d(1024, 128, reverse=True)
+
+
+def test_cumsum_fragment_1d():
+    run_cumsum_1d(1024, 128, scope="fragment")
+    run_cumsum_1d(1024, 128, reverse=True, scope="fragment")
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
@@ -160,6 +160,29 @@ def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
     Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
 
+    Examples:
+        A 1D inclusive scan performed in place on a shared-memory tile:
+
+        >>> import tilelang.language as T
+        >>> @T.prim_func
+        ... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
+        ...     with T.Kernel(1, threads=128):
+        ...         A_shared = T.alloc_shared((128,), "float32")
+        ...         T.copy(A, A_shared)
+        ...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
+        ...         T.copy(A_shared, B)
+
+        A 2D prefix sum along the last dimension with reverse accumulation:
+
+        >>> import tilelang.language as T
+        >>> @T.prim_func
+        ... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
+        ...     with T.Kernel(1, 1, threads=256):
+        ...         tile = T.alloc_shared((64, 64), "float16")
+        ...         T.copy(A, tile)
+        ...         T.cumsum(src=tile, dim=1, reverse=True)
+        ...         T.copy(tile, B)
+
     Returns:
         tir.Call: A handle to the emitted cumulative-sum operation.
     """
...