logical.cc 1.59 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
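/*!
 * \file tl/op/logical.cc
 * \brief Buffer-wide logical reduction intrinsics.
 *
 * Registers the `tl.any_of` and `tl.all_of` TIR intrinsics and their CUDA
 * lowering to the device-side `tl::Any` / `tl::All` extern calls.
 */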

#include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

namespace tvm {
namespace tl {
using namespace tir;

namespace {

/*!
 * \brief Shared CUDA lowering for the buffer-wide logical intrinsics.
 *
 * Rewrites a `tl.any_of` / `tl.all_of` call into a Bool-typed
 * `call_extern(extern_name, buffer_address, elems)`.
 *
 * \param call_expr The intrinsic call being lowered. It must be a Call node
 *        with exactly two arguments: the buffer address and the element
 *        count, in that order.
 * \param extern_name Name of the extern function to emit (e.g. "tl::Any").
 * \return The lowered Bool-typed extern call.
 */
PrimExpr LowerBufferLogicalReduce(const PrimExpr &call_expr,
                                  const char *extern_name) {
  const CallNode *call = call_expr.as<CallNode>();
  ICHECK(call != nullptr) << "Lowering to " << extern_name
                          << " expects a Call node";
  const Array<PrimExpr> &call_args = call->args;
  // Layout produced by the frontend: (buffer_address, elems).
  ICHECK_EQ(call_args.size(), 2)
      << extern_name << " expects (buffer_address, elems), but got "
      << call_args.size() << " argument(s)";
  PrimExpr buffer_address = call_args[0];
  PrimExpr elems = call_args[1];
  return tir::Call(DataType::Bool(), tir::builtin::call_extern(),
                   {StringImm(extern_name), buffer_address, elems});
}

} // namespace

/*!
 * \brief Lower `tl.any_of(buffer_address, elems)` to an extern `tl::Any` call.
 * \param call_expr The `tl.any_of` call node.
 * \return A Bool-typed `call_extern("tl::Any", buffer_address, elems)`.
 */
PrimExpr any_of_op(PrimExpr call_expr) {
  return LowerBufferLogicalReduce(call_expr, "tl::Any");
}

/*!
 * \brief Lower `tl.all_of(buffer_address, elems)` to an extern `tl::All` call.
 * \param call_expr The `tl.all_of` call node.
 * \return A Bool-typed `call_extern("tl::All", buffer_address, elems)`.
 */
PrimExpr all_of_op(PrimExpr call_expr) {
  return LowerBufferLogicalReduce(call_expr, "tl::All");
}

// Both intrinsics take two inputs (buffer_address, elems). That matches the
// arity enforced by the lowering above; the previous declaration of 1 input
// contradicted it.
TVM_REGISTER_OP("tl.any_of")
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "any_of")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op);

TVM_REGISTER_OP("tl.all_of")
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "all_of")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op);

} // namespace tl
} // namespace tvm