/*!
 * \file tl/op/math.cc
 * \brief Math intrinsic operators and their target lowering rules.
 */

#include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

namespace tvm {
namespace tl {
using namespace tir;

/*!
 * \brief Lower a `tl.pow_of_int(base, exponent)` call to an extern call of
 *        the templated helper `tl::pow_of_int<exponent>(base)`.
 *
 * \param args The `tl.pow_of_int` call expression being lowered. Despite the
 *        name this is a single PrimExpr (the whole call), not an argument list.
 * \return A `call_extern` expression whose dtype is that of `base`.
 *
 * The exponent is spliced into the helper's name as a C++ template argument,
 * so it must be a compile-time integer constant (IntImm). A non-constant
 * exponent is reported as an error rather than dereferencing a null pointer.
 */
PrimExpr pow_of_int_op(PrimExpr args) {
  const CallNode *call = args.as<CallNode>();
  ICHECK(call != nullptr) << "tl.pow_of_int lowering expects a Call expression";

  const Array<PrimExpr> &call_args = call->args;
  ICHECK_EQ(call_args.size(), 2)
      << "tl.pow_of_int expects exactly 2 arguments (base, exponent)";

  PrimExpr base = call_args[0];
  // The exponent must be an integer literal: it becomes a template argument.
  const auto *exp_imm = call_args[1].as<IntImmNode>();
  CHECK(exp_imm != nullptr)
      << "tl.pow_of_int requires a constant integer exponent, but got "
      << call_args[1];

  String pow_of_int_name =
      "tl::pow_of_int<" + std::to_string(exp_imm->value) + ">";
  return tir::Call(base.dtype(), tir::builtin::call_extern(),
                   {StringImm(pow_of_int_name), base});
}

// Registers `tl.pow_of_int(base, exponent)`: a two-input intrinsic marked pure
// (kPure), printed as `pow_of_int` in TVMScript, and lowered for the CUDA
// target by pow_of_int_op into a call to `tl::pow_of_int<exponent>(base)`.
TVM_REGISTER_OP("tl.pow_of_int")
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);

} // namespace tl
} // namespace tvm