Unverified Commit bc37ea69 authored by Lei Wang, committed by GitHub

[Language] Efficient `T.reduce_` with shared memory input/output (#1080)

* Support reduce ss

* lint fix

* test fix

* lint fix
parent a7730272
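This change lets `T.reduce_*` operate directly on shared-memory buffers instead of requiring fragment (register) buffers. A minimal sketch of the usage it enables, mirroring the `_make_shared_reduce` helper in the tests below (the shapes, dtype, and single-block kernel are illustrative, not part of the commit):

import tilelang as tl
import tilelang.language as T

@T.prim_func
def row_max(A: T.Tensor((64, 64), "float32"), B: T.Tensor((64,), "float32")):
    with T.Kernel(1) as _:
        # Stage input and output in shared memory rather than fragments.
        A_shared = T.alloc_shared((64, 64), "float32")
        B_shared = T.alloc_shared((64,), "float32")
        T.copy(A, A_shared)
        # Reduce along dim=1 with shared-memory input/output.
        T.reduce_max(A_shared, B_shared, dim=1)
        T.copy(B_shared, B)

kernel = tl.compile(row_max, out_idx=-1)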
@@ -175,11 +175,18 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
  * @return Stmt Lowered TIR statement implementing the reduction.
  */
 Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
-  ICHECK(this->src.scope() == "local.fragment" &&
-         this->dst.scope() == "local.fragment")
-      << "Reduce for shared memory not implemented.";
-  auto src_buffer = T.buffer_remap[this->src];
-  auto dst_buffer = T.buffer_remap[this->dst];
+  auto get_buffer = [&](const Buffer &buf) {
+    if (T.buffer_remap.count(buf))
+      return T.buffer_remap[buf];
+    return buf;
+  };
+  auto src_scope = this->src.scope();
+  auto dst_scope = this->dst.scope();
+  if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
+    Buffer src_buffer = get_buffer(this->src);
+    Buffer dst_buffer = get_buffer(this->dst);
     Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
     Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
     size_t src_dim = src_layout->InputDim();
@@ -191,22 +198,24 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       ICHECK(is_one(dst_layout->OutputShape().back()))
           << "Reduce for scalar not implemented.";
     } else {
-      ICHECK(src_dim == dst_dim + 1) << "Reduce dimension mismatch.";
+      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
     }
     Array<IterVar> dst_vars;
-    for (size_t i = 0; i < dst_dim; i++) {
+    for (size_t i = 0; i < dst_dim; ++i) {
       Var var = Var(std::string{char('i' + i)});
       dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                                  IterVarType::kDataPar));
     }
     Array<IterVar> src_vars;
     if (!is_1d_reduce) {
       src_vars = dst_vars;
     }
-    src_vars.insert(src_vars.begin() + this->dim,
-                    {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
-                     IterVarType::kDataPar});
+    Range reduce_dom(0, src_layout->InputShape()[this->dim]);
+    IterVar reduce_iv(reduce_dom, Var("rv"), IterVarType::kDataPar);
+    src_vars.insert(src_vars.begin() + this->dim, reduce_iv);
     Array<PrimExpr> src_indices = src_layout->Forward(
         src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
     Array<PrimExpr> dst_indices = dst_layout->Forward(
@@ -215,30 +224,20 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     Array<Stmt> stmts;
     bool require_init = this->clear;
-    // sum op must be cleared
-    if (this->type->isSum()) {
-      require_init = true;
-    } else if (this->type->isAbsSum()) {
-      require_init = true;
-    } else if (this->type->isBitAnd()) {
-      require_init = true;
-    } else if (this->type->isBitOr()) {
-      require_init = true;
-    } else if (this->type->isBitXor()) {
+    if (this->type->isSum() || this->type->isAbsSum() ||
+        this->type->isBitAnd() || this->type->isBitOr() ||
+        this->type->isBitXor()) {
       require_init = true;
     }
     Buffer clear_buffer = dst_buffer;
     bool need_duplicate = false;
-    if (this->type->isSum() && !this->clear) {
-      need_duplicate = true;
-    } else if (this->type->isAbsSum() && !this->clear) {
-      need_duplicate = true;
-    } else if (this->type->isBitAnd()) {
+    if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
       need_duplicate = true;
-    } else if (this->type->isBitOr() && !this->clear) {
+    } else if (this->type->isBitAnd() && !this->clear) {
       need_duplicate = true;
-    } else if (this->type->isBitXor() && !this->clear) {
+    } else if ((this->type->isBitOr() || this->type->isBitXor()) &&
+               !this->clear) {
       need_duplicate = true;
     }
@@ -248,7 +247,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                            dst_buffer->name + "_clear",
                            GetPtrStorageScope(dst_buffer->data));
     }
     // make reduce-init stmt
     if (require_init) {
       stmts.push_back(
@@ -258,20 +256,22 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     // make thread-local reduce
     Array<PrimExpr> src_indice_compressed;
     Array<IterVar> src_var_compressed;
-    for (size_t i = 0; i < src_layout->OutputDim(); i++) {
+    for (size_t i = 0; i < src_layout->OutputDim(); ++i) {
       PrimExpr expr;
       IterVar var;
-      std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
-                                             src_vars[this->dim]->var, analyzer);
+      std::tie(expr, var) = CompressIterator(
+          src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
       src_indice_compressed.push_back(expr);
       src_var_compressed.push_back(var);
     }
     Stmt reduce_local = BufferStore(
         clear_buffer,
         this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                          BufferLoad(src_buffer, src_indice_compressed)),
         dst_indices);
-    for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
+    for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
       reduce_local =
           For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
               ForKind::kUnrolled, reduce_local, std::nullopt,
@@ -279,7 +279,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     }
     stmts.push_back(reduce_local);
-    // make inter-thread reduce
     PrimExpr src_thread = src_layout->ForwardThread(
         src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
     auto iter_sum =
@@ -315,67 +314,118 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
               *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
           thread_reduce_args.push_back(workspace);
         }
-        auto call =
-            Call(clear_buffer->dtype, builtin::call_extern(), thread_reduce_args);
+        auto call = Call(clear_buffer->dtype, builtin::call_extern(),
+                         thread_reduce_args);
         stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
       }
     }
-    Stmt reduce_interthread = BufferStore(
-        clear_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices);
-    // copy clear_buffer to dst_buffer
     if (need_duplicate) {
-      // if is reduce sum, we should add a copy from clear_buffer to dst_buffer
-      if (this->type->isSum()) {
-        stmts.push_back(BufferStore(dst_buffer,
-                                    Add(BufferLoad(dst_buffer, dst_indices),
-                                        BufferLoad(clear_buffer, dst_indices)),
-                                    dst_indices));
-      } else if (this->type->isAbsSum()) {
-        stmts.push_back(BufferStore(dst_buffer,
-                                    Add(BufferLoad(dst_buffer, dst_indices),
-                                        BufferLoad(clear_buffer, dst_indices)),
-                                    dst_indices));
+      PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
+      PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
+      PrimExpr update;
+      if (this->type->isSum() || this->type->isAbsSum()) {
+        update = dst_val + src_val;
       } else if (this->type->isBitAnd()) {
-        if (!this->clear) {
-          stmts.push_back(
-              BufferStore(dst_buffer,
-                          bitwise_and(BufferLoad(dst_buffer, dst_indices),
-                                      BufferLoad(clear_buffer, dst_indices)),
-                          dst_indices));
-        } else {
-          stmts.push_back(BufferStore(
-              dst_buffer, BufferLoad(clear_buffer, dst_indices), dst_indices));
-        }
+        update = this->clear ? src_val : bitwise_and(dst_val, src_val);
       } else if (this->type->isBitOr()) {
-        stmts.push_back(
-            BufferStore(dst_buffer,
-                        bitwise_or(BufferLoad(dst_buffer, dst_indices),
-                                   BufferLoad(clear_buffer, dst_indices)),
-                        dst_indices));
+        update = bitwise_or(dst_val, src_val);
       } else if (this->type->isBitXor()) {
-        stmts.push_back(
-            BufferStore(dst_buffer,
-                        bitwise_xor(BufferLoad(dst_buffer, dst_indices),
-                                    BufferLoad(clear_buffer, dst_indices)),
-                        dst_indices));
+        update = bitwise_xor(dst_val, src_val);
       } else {
-        ICHECK(false) << "Unsupported reduce type: " << this->type->type;
+        LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
       }
+      stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
     }
-    // make the outer spatial loop
     Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
-    for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
+    for (int i = static_cast<int>(dst_layout->InputDim()) - 1; i >= 0; --i) {
       body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
                  ForKind::kParallel, body);
     }
-    body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
+    if (dst_layout->InputDim() > 0) {
+      body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
+                           dst_layout);
+    } else {
+      PrimExpr guard = (T.thread_var == T.thread_bounds->min);
+      body = IfThenElse(guard, body);
+    }
     if (need_duplicate) {
       body = Allocate(clear_buffer->data, clear_buffer->dtype,
                       clear_buffer->shape, const_true(), body);
     }
     return body;
+  }
+
+  auto is_shared_scope = [](const std::string &scope) {
+    return scope == "shared" || scope == "shared.dyn";
+  };
+  if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
+    Buffer src_buffer = get_buffer(this->src);
+    Buffer dst_buffer = get_buffer(this->dst);
+    size_t src_dim = src_buffer->shape.size();
+    size_t dst_dim = dst_buffer->shape.size();
+    bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
+    if (!is_1d_reduce) {
+      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
+    } else {
+      ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
+    }
+    auto thread_extent = as_const_int(T.thread_bounds->extent);
+    ICHECK(thread_extent)
+        << "Shared-memory reduce requires static thread extent.";
+    int threads = *thread_extent;
+    if (TargetIsCuda(T.target)) {
+      ICHECK_EQ(threads % 32, 0)
+          << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
+    } else if (TargetIsRocm(T.target)) {
+      ICHECK_EQ(threads % 64, 0)
+          << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
+    }
+    bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
+    bool need_accumulate =
+        (!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
+                           this->type->isBitAnd() || this->type->isBitOr() ||
+                           this->type->isBitXor());
+    PrimExpr reduce_extent = src_buffer->shape[this->dim];
+    PrimExpr tail_extent = make_const(DataType::Int(32), 1);
+    for (size_t i = this->dim + 1; i < src_dim; ++i) {
+      tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
+    }
+    PrimExpr total_dest = make_const(DataType::Int(32), 1);
+    for (size_t i = 0; i < dst_dim; ++i) {
+      total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
+    }
+    std::stringstream ss;
+    std::string reducer = this->MakeCodegenReducer();
+    ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
+       << (use_abs ? "true" : "false") << ", "
+       << (need_accumulate ? "true" : "false") << ">::run";
+    Array<PrimExpr> call_args = {StringImm(ss.str()),
+                                 src_buffer.access_ptr(1),
+                                 dst_buffer.access_ptr(3),
+                                 cast(DataType::Int(32), total_dest),
+                                 cast(DataType::Int(32), reduce_extent),
+                                 cast(DataType::Int(32), tail_extent),
+                                 this->MakeInitValue()};
+    return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
+  }
+
+  LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
+             << dst_scope << ") is not implemented.";
+  return Stmt();
 }

 LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
...
@@ -40,6 +40,53 @@ struct BitXorOp {
  }
};
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
template <typename T>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int total_dest, int reduce_extent, int tail,
T init_value) {
if (total_dest <= 0 || reduce_extent <= 0)
return;
constexpr int kWarpSize = 32;
static_assert(Threads % kWarpSize == 0,
"SharedReduceWarp expects blockDim.x to be a multiple of "
"warp size on CUDA.");
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane = tid % kWarpSize;
const int num_warps = Threads / kWarpSize;
for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
const int suffix = tail == 1 ? 0 : dest_idx % tail;
const int src_base = (prefix * reduce_extent) * tail + suffix;
const int dst_index = prefix * tail + suffix;
T partial = init_value;
for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
T val = src[src_base + rv * tail];
if constexpr (UseAbs) {
val = val < T(0) ? -val : val;
}
partial = Reducer()(partial, val);
}
unsigned mask = __activemask();
for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
T other = __shfl_down_sync(mask, partial, offset);
partial = Reducer()(partial, other);
}
if (lane == 0) {
if constexpr (NeedAccumulate) {
partial = Reducer()(dst[dst_index], partial);
}
dst[dst_index] = partial;
}
}
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
...
@@ -22,6 +22,71 @@ struct MinOp {
  }
};
struct BitAndOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x & y;
}
};
struct BitOrOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x | y;
}
};
struct BitXorOp {
template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
return x ^ y;
}
};
template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
template <typename T>
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int total_dest, int reduce_extent, int tail,
T init_value) {
if (total_dest <= 0 || reduce_extent <= 0)
return;
constexpr int kWarpSize = 64;
static_assert(Threads % kWarpSize == 0,
"SharedReduceWarp expects blockDim.x to be a multiple of "
"wave size on HIP.");
const int tid = threadIdx.x;
const int warp_id = tid / kWarpSize;
const int lane = tid % kWarpSize;
const int num_warps = Threads / kWarpSize;
for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
const int suffix = tail == 1 ? 0 : dest_idx % tail;
const int src_base = (prefix * reduce_extent) * tail + suffix;
const int dst_index = prefix * tail + suffix;
T partial = init_value;
for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
T val = src[src_base + rv * tail];
if constexpr (UseAbs) {
val = val < T(0) ? -val : val;
}
partial = Reducer()(partial, val);
}
for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
T other = __shfl_down(partial, offset, kWarpSize);
partial = Reducer()(partial, other);
}
if (lane == 0) {
if constexpr (NeedAccumulate) {
partial = Reducer()(dst[dst_index], partial);
}
dst[dst_index] = partial;
}
}
}
};
template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
  static_assert(threads == 1024 || threads == 512 || threads == 256 ||
...
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()

def _make_shared_reduce(M, N, dtype, reduce_cb):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)
            T.copy(A, A_shared)
            reduce_cb(T, A_shared, B_shared)
            T.copy(B_shared, B)

    return main

def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()
    profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)

def reduce_max_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.reduce_max(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main

def reduce_sum_test(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.reduce_sum(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main

def reduce_sum_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))

def reduce_max_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))

def reduce_min_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1))

def reduce_abssum_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1))

def reduce_absmax_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1))

def run_reduce_sum(M, N, dtype="float32", mode="rr"):
    if mode == "rr":
        program = reduce_sum_test(M, N, dtype)
    elif mode == "ss":
        program = reduce_sum_ss(M, N, dtype)
    else:
        raise NotImplementedError("run_reduce_sum only supports rr and ss")
    _run_program(program, lambda A: A.sum(dim=1))

def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"):
    program = program_builder(M, N, dtype)
    _run_program(program, ref_program)

def run_reduce_max(M, N, dtype="float16"):
    program = reduce_max_test(M, N, dtype)
    _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2)

def test_reduce_sum():
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)

def test_reduce_sum_shared():
    run_reduce_sum(64, 64, mode="ss")
    run_reduce_sum(32, 96, mode="ss")

def test_reduce_max():
    run_reduce_max(256, 256, "float16")
    run_reduce_max(512, 128, "float16")
    run_reduce_max(256, 256, "float32")

def test_reduce_max_shared():
    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")

def test_reduce_min_shared():
    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, "float32")

def test_reduce_abssum_shared():
    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, "float32")

def test_reduce_absmax_shared():
    run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, "float32")

def reduce_sum_test_clear(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main

def run_reduce_sum_clear(M, N, dtype="float32"):
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch
    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)

def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, "float32")
    run_reduce_sum_clear(512, 128, "float32")
    run_reduce_sum_clear(128, 512, "float32")

def reduce_max_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main

def run_reduce_max_clear(M, N, dtype="float16"):
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.max(dim=1).values

    import torch
    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)

def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, "float16")

if __name__ == "__main__":
    tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

def reduce_max_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            # Copy input to local
            T.copy(A, A_local)
            # Perform reduce_max operation
            T.reduce_max(A_local, B_local, dim=1)
            # Copy result back
            T.copy(B_local, B)

    return main

def run_reduce_max(M, N, dtype="float16"):
    program = reduce_max_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.max(dim=1).values

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

def test_reduce_max():
    # Test different sizes
    run_reduce_max(256, 256)
    run_reduce_max(512, 128)
    run_reduce_max(128, 512)
    # Test different dtypes
    run_reduce_max(256, 256, "float32")
    run_reduce_max(256, 256, "float16")

def reduce_max_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main

def run_reduce_max_clear(M, N, dtype="float16"):
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    print(jit_kernel.get_kernel_source())

    def ref_program(A):
        return A.max(dim=1).values

    import torch
    dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummp_A)
    tl_out = jit_kernel(dummp_A)
    print(tl_out)
    print(ref_out)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)

def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, "float16")

if __name__ == "__main__":
    tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()

def reduce_sum_test(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            # Copy input to local
            T.copy(A, A_local)
            # Perform reduce_sum operation
            T.reduce_sum(A_local, B_local, dim=1)
            # Copy result back
            T.copy(B_local, B)

    return main

def run_reduce_sum(M, N, dtype="float32"):
    program = reduce_sum_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.sum(dim=1)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

def test_reduce_sum():
    # Test different sizes
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)

def reduce_sum_test_clear(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main

def run_reduce_sum_clear(M, N, dtype="float32"):
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch
    dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummp_A)
    tl_out = jit_kernel(dummp_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)

def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, "float32")
    run_reduce_sum_clear(512, 128, "float32")
    run_reduce_sum_clear(128, 512, "float32")

if __name__ == "__main__":
    tilelang.testing.main()
@@ -12,7 +12,7 @@ from tvm.tir.stmt_functor import post_order_visit
 PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
     cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
-    if (result_{0} != CUDA_SUCCESS) {{
+    if (result_{0} != cudaSuccess) {{
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
        return -1;
     }}
...