"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "7748a2d4232ec3a8d1104a1a9d9114a907d3b765"
Commit 88747fcd (parent ae1e7399), authored by Lei Wang, committed by LeiWang1999

[Language] Support tile operator `T.cumsum` (#423)

* [Feature] Implement CumSum operation in TileLang

* Added CumSumOp class for cumulative sum operations, including argument validation and lowering logic.
* Introduced CumSum2D template for CUDA, supporting both forward and reverse cumulative sums.
* Created tests for CumSum functionality in shared memory and fragment contexts.
* Updated language interface to include the cumsum operation, enhancing the reduction capabilities of TileLang (see the usage sketch below).
* Refactored reduce.py to support cumsum functionality with appropriate memory allocation and copying mechanisms.

* lint fix
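For context, the operator is exposed as `T.cumsum(src, dst=None, dim=0, reverse=False)` and works in place when `dst` is omitted. A minimal usage sketch, modeled on the tests added in this commit (the shape, dtype, and single-block launch are illustrative, not part of the diff):

import tilelang.language as T

@T.prim_func
def kernel(
        A: T.Tensor((128, 128), "float16"),
        B: T.Tensor((128, 128), "float16"),
):
    with T.Kernel(1, threads=256) as _:
        A_shared = T.alloc_shared((128, 128), "float16")
        T.copy(A, A_shared)
        # In-place inclusive cumsum along dim 1 (each row scanned left to right).
        T.cumsum(src=A_shared, dim=1, reverse=False)
        T.copy(A_shared, B)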
@@ -237,5 +237,55 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
  /*
    CumSum arguments:
      src: input buffer
      dst: output buffer
      dim: dimension along which to accumulate
      reverse: whether to accumulate in reverse order
  */
  CHECK_EQ(args.size(), 4);
  src = vmap[GetVarFromAccessPtr(args[0])];
  dst = vmap[GetVarFromAccessPtr(args[1])];
  dim = args[2].as<IntImm>().value()->value;
  reverse = args[3].as<Bool>().value();
  CHECK_GE(dim, 0);
  CHECK_LT(dim, static_cast<int>(src->shape.size()));
}
Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (this->src.scope() == "local.fragment" &&
      this->dst.scope() == "local.fragment") {
    LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue "
                  "if you need this feature.";
  } else if (this->src.scope() == "shared.dyn" ||
             this->src.scope() == "shared") {
    ICHECK(this->dst.scope() == "shared.dyn" || this->dst.scope() == "shared");
    // Build the template instantiation name, e.g. tl::CumSum2D<256, 1, false>::run.
    std::stringstream ss;
    auto threads = T.thread_bounds->extent - T.thread_bounds->min;
    ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
       << (reverse ? "true" : "false") << ">::run";
    Array<PrimExpr> args = {StringImm(ss.str()), src.access_ptr(1),
                            dst.access_ptr(3)};
    // Append the buffer shape as trailing call arguments.
    for (size_t i = 0; i < src->shape.size(); i++) {
      args.push_back(src->shape[i]);
    }
    return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
  } else {
    ICHECK(false) << "Cannot lower cumsum for " << this->src.scope() << " and "
                  << this->dst.scope();
  }
  return Stmt();
}
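// For illustration: with 256 threads, dim = 1, reverse = false, and an
// (H, W) shared-memory buffer, the call generated above is equivalent to
//   tl::CumSum2D<256, 1, false>::run(src_ptr, dst_ptr, H, W);
// (a sketch of the emitted extern call, not code from this commit).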
LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  // CumSum imposes no layout constraints.
  return {};
}
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
@@ -38,6 +38,19 @@ private:
  std::string MakeCodegenReducer() const;
};
class CumSumOp : public Operator {
public:
  CumSumOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

private:
  tir::Buffer src, dst;
  int dim;
  bool reverse;
};
} // namespace tl
} // namespace tvm
@@ -65,4 +65,83 @@ struct AllReduce {
  }
};
template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
                threads == 128 or threads == 64 or threads == 32);

  // Each 32-lane warp scans one row in SEG-wide segments, carrying the
  // running total across segments with warp shuffles.
  template <typename T, int SEG = 32>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int H, int W) {
    constexpr int TILE_H = threads / SEG;
    constexpr unsigned MASK = 0xffffffff;
    const int num_blocks = (H + TILE_H - 1) / TILE_H;
    const int tid = threadIdx.x;
    const int lane = tid % 32; // lane within the warp (assumes SEG == 32)
    const int row = tid / 32;  // warp index == row within the tile
    for (int b = 0; b < num_blocks; ++b) {
      const int gRow = b * TILE_H + row;
      if (gRow >= H)
        return;
      T carry = (T)0;
      if (reverse) {
        // Start from the last segment for reverse mode
        for (int seg = (W + SEG - 1) / SEG - 1; seg >= 0; --seg) {
          const int col = seg * SEG + lane;
          const int real_row = Axis == 1 ? gRow : col;
          const int real_col = Axis == 1 ? col : gRow;
          T val = (col < W) ? src[real_row * W + real_col] : (T)0;
          // Inclusive suffix scan within the segment.
#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_down_sync(MASK, val, off);
            if (lane < SEG - off)
              val += n;
          }
          val += carry;
          if (real_col < W)
            dst[real_row * W + real_col] = val;
          // Lane 0 now holds the segment total (carry included).
          T segSum = (T)__shfl_sync(MASK, val, 0);
          if (lane == 0)
            carry = segSum;
          carry = (T)__shfl_sync(MASK, carry, 0);
        }
      } else {
        for (int seg = 0; seg * SEG < W; ++seg) {
          const int col = seg * SEG + lane;
          const int real_row = Axis == 1 ? gRow : col;
          const int real_col = Axis == 1 ? col : gRow;
          T val = (col < W) ? src[real_row * W + real_col] : (T)0;
          // Inclusive prefix scan within the segment.
#pragma unroll
          for (int off = 1; off < SEG; off <<= 1) {
            T n = (T)__shfl_up_sync(MASK, val, off);
            if (lane >= off)
              val += n;
          }
          val += carry;
          if (real_col < W)
            dst[real_row * W + real_col] = val;
          // The last lane holds the segment total (carry included).
          T segSum = (T)__shfl_sync(MASK, val, SEG - 1);
          if (lane == SEG - 1)
            carry = segSum;
          carry = (T)__shfl_sync(MASK, carry, SEG - 1);
        }
      }
    }
  }
};
} // namespace tl
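As a reference for the semantics CumSum2D implements, the following PyTorch sketch computes the same inclusive cumsum with optional reverse mode (the helper name `cumsum2d_ref` is illustrative, not part of this commit):

import torch

def cumsum2d_ref(src: torch.Tensor, axis: int = 0, reverse: bool = False) -> torch.Tensor:
    # A reverse (suffix) inclusive scan is flip -> cumsum -> flip on the same axis.
    if reverse:
        return src.flip(dims=[axis]).cumsum(dim=axis).flip(dims=[axis])
    return src.cumsum(dim=axis)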
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def cumsum(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.cumsum(src=A_shared, dim=dim, reverse=reverse)
            T.copy(A_shared, B[by * block_M, bx * block_N])

    return cumsum
def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def cumsum(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            A_fragment = T.alloc_fragment((block_M, block_N), dtype)
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, A_fragment)
            T.cumsum(src=A_fragment, dim=dim, reverse=reverse)
            T.copy(A_fragment, B[by * block_M, bx * block_N])

    return cumsum
def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float16", scope="smem"):
    if scope == "smem":
        program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype)
    elif scope == "fragment":
        program = cumsum_fragment_test(M, N, block_M, block_N, dim, reverse, dtype)
    else:
        raise ValueError(f"unknown scope: {scope}")
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.One)

    def ref_program(A):
        ref_b = torch.empty_like(A)
        for i in range(M // block_M):
            for j in range(N // block_N):
                tile = A[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N]
                if reverse:
                    # A reverse cumsum is flip -> cumsum -> flip along the same dim.
                    tile = tile.flip(dims=[dim]).cumsum(dim=dim).flip(dims=[dim])
                else:
                    tile = tile.cumsum(dim=dim)
                ref_b[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = tile
        return ref_b

    profiler.assert_allclose(ref_program)
def test_cumsum_smem():
    # Test different sizes
    run_cumsum(1024, 1024, 128, 128)
    run_cumsum(1024, 1024, 128, 128, dim=1)
    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True)
    # Test different dtypes
    run_cumsum(256, 256, 128, 128, dtype="float32")
    run_cumsum(256, 256, 128, 128, dtype="float16")


def test_cumsum_fragment():
    run_cumsum(1024, 1024, 128, 128, scope="fragment")
    run_cumsum(1024, 1024, 128, 128, dim=1, scope="fragment")
    run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment")
    # Test different dtypes
    run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
    run_cumsum(256, 256, 128, 128, dtype="float16", scope="fragment")


if __name__ == "__main__":
    tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
tilelang.disable_cache()
def reduce_sum_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            # Copy input to local
            T.copy(A, A_local)
            # Perform reduce_sum operation
            T.reduce_sum(A_local, B_local, dim=0)
            # Copy result back
            T.copy(B_local, B)

    return main


def run_reduce_sum(M, N, dtype="float16"):
    program = reduce_sum_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    print(jit_kernel.get_kernel_source())
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.sum(dim=0)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    # Test different sizes
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)
    # Test different dtypes
    run_reduce_sum(256, 256, "float32")
    run_reduce_sum(256, 256, "float16")


if __name__ == "__main__":
    tilelang.testing.main()
@@ -48,6 +48,7 @@ from .reduce import (
    reduce_sum,  # noqa: F401
    reduce_abssum,  # noqa: F401
    reduce_absmax,  # noqa: F401
    cumsum,  # noqa: F401
)
from .print import print  # noqa: F401
from .customize import (
"""The language interface for tl programs.""" """The language interface for tl programs."""
from tvm import tir from tvm import tir
from typing import Optional
from tilelang.language import copy, macro, alloc_shared
def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
...@@ -104,3 +106,33 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int): ...@@ -104,3 +106,33 @@ def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
return reduce(buffer, out, "absmax", dim, True) return reduce(buffer, out, "absmax", dim, True)
@macro
def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -> tir.PrimExpr:
    # Fragment cumsum is lowered by staging through dynamic shared memory:
    # copy the fragment in, run the shared-memory cumsum in place, copy back.
    cumsum_smem = alloc_shared(src.shape, src.dtype, "shared.dyn")
    copy(src, cumsum_smem)
    tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.cumsum"),
        cumsum_smem.access_ptr("r"),
        cumsum_smem.access_ptr("w"),
        dim,
        reverse,
    )
    copy(cumsum_smem, dst)
def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
    # When dst is omitted, the cumsum is performed in place on src.
    if dst is None:
        dst = src
    if src.scope() == "local.fragment":
        return cumsum_fragment(src, dst, dim, reverse)
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.cumsum"),
        src.access_ptr("r"),
        dst.access_ptr("w"),
        dim,
        reverse,
    )
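For illustration, the dispatch above behaves as follows inside a kernel (a sketch; `A_shared` and `A_frag` stand for previously allocated shared-memory and fragment buffers, not names from this commit):

T.cumsum(src=A_shared, dim=1)               # shared scope: lowers directly to the tl.cumsum intrinsic
T.cumsum(src=A_frag, dim=0, reverse=True)   # fragment scope: staged through shared memory by cumsum_fragment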
@@ -292,6 +292,8 @@ def torch_assert_close(
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
            f"{mismatch_info}"
            f"\nGreatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}"
            f"\nLHS: {tensor_a}"
            f"\nRHS: {tensor_b}")
    else:
        return True