Commit 7c266adf authored by Cunxiao Ni's avatar Cunxiao Ni Committed by LeiWang1999
Browse files

[Enhancement] Move T.any_of and T.all_of op registration from python into cpp (#398)

* [Enhancement] Move T.any_of and T.all_of op registration from python into cpp

* format

* add license
parent cffcf1c2
/*!
* \file tl/op/logical.cc
* \brief Logical operations.
*
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
PrimExpr any_of_op(PrimExpr args) {
const CallNode *call = args.as<CallNode>();
CHECK(call != nullptr);
const Array<PrimExpr> &arg = call->args;
ICHECK_EQ(arg.size(), 2);
PrimExpr buffer_address = arg[0];
PrimExpr elems = arg[1];
return tir::Call(DataType::Bool(), tir::builtin::call_extern(),
{StringImm("tl::Any"), buffer_address, elems});
}
PrimExpr all_of_op(PrimExpr args) {
const CallNode *call = args.as<CallNode>();
CHECK(call != nullptr);
const Array<PrimExpr> &arg = call->args;
ICHECK_EQ(arg.size(), 2);
PrimExpr buffer_address = arg[0];
PrimExpr elems = arg[1];
return tir::Call(DataType::Bool(), tir::builtin::call_extern(),
{StringImm("tl::All"), buffer_address, elems});
}
TVM_REGISTER_OP("tl.any_of")
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "any_of")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op);
TVM_REGISTER_OP("tl.all_of")
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "all_of")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op);
} // namespace tl
} // namespace tvm
\ No newline at end of file
"""The language interface for tl programs."""
from tilelang import language as T
import tvm
from tvm.tir import Buffer, BufferRegion
from tvm.ir import Range
from tvm.ir import register_op_attr, register_intrin_lowering
from tvm import tir
from typing import Union
from tilelang.utils.language import get_buffer_elems
# TODO: move this part into src to reduce runtime overhead
def any_of_op(op):
args = op.args
assert len(args) == 2
buffer_address, elems = args
return T.call_extern("bool", "tl::Any", buffer_address, elems)
register_op_attr("tl.any_of", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
register_op_attr("tl.any_of", "TScriptPrinterName", "any_of")
register_intrin_lowering("tl.any_of", target="cuda", f=any_of_op)
def any_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if any element in the buffer is true.
......@@ -57,18 +42,6 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
raise ValueError(f"Invalid buffer type: {type(buffer)}")
def all_of_op(op):
args = op.args
assert len(args) == 2
buffer_address, elems = args
return T.call_extern("bool", "tl::All", buffer_address, elems)
register_op_attr("tl.all_of", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
register_op_attr("tl.all_of", "TScriptPrinterName", "all_of")
register_intrin_lowering("tl.all_of", target="cuda", f=all_of_op)
def all_of(buffer: Union[T.Tensor, BufferRegion]):
"""Check if all elements in the buffer are true.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment