math.cc 1.05 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file tl/op/math.cc
 * \brief Math operations.
 *
 */

7
#include <tvm/ffi/function.h>
8
9
10
11
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

12
13
#include "../support/ffi_aliases.h"

14
15
16
17
namespace tvm {
namespace tl {
using namespace tir;

18
/*!
 * \brief Lower a `tl.pow_of_int(base, exp)` call into an extern call of
 *        `tl::pow_of_int<exp>(base)`.
 *
 * \param args The `tl.pow_of_int` call expression being lowered. Despite the
 *        name, this is the whole Call node, not its argument list.
 * \return A `call_extern` expression with the same dtype as `base`. Its
 *         callee name is `tl::pow_of_int<N>`, where N is the exponent value.
 *
 * The exponent is spliced into the callee name as a template argument, so it
 * must be a constant integer (IntImm). Any other exponent is rejected with a
 * diagnostic instead of dereferencing a null pointer.
 * NOTE(review): the `tl::pow_of_int` template itself is defined outside this
 * file (presumably in the device-side headers) — confirm which exponent
 * values it supports.
 */
PrimExpr pow_of_int_op(PrimExpr args) {
  const CallNode *call = args.as<CallNode>();
  CHECK(call != nullptr) << "tl.pow_of_int lowering expects a Call expression";
  const Array<PrimExpr> &arg = call->args;
  ICHECK_EQ(arg.size(), 2)
      << "tl.pow_of_int expects exactly 2 arguments (base, exp)";
  PrimExpr base = arg[0];
  PrimExpr exp = arg[1];

  // The exponent becomes a C++ template argument, so it has to be a constant.
  const IntImmNode *exp_imm = exp.as<IntImmNode>();
  ICHECK(exp_imm != nullptr)
      << "tl.pow_of_int requires a constant integer exponent, but got " << exp;

  String pow_of_int_name =
      "tl::pow_of_int<" + std::to_string(exp_imm->value) + ">";
  return tir::Call(base.dtype(), tir::builtin::call_extern(),
                   {StringImm(pow_of_int_name), base});
}

31
// Register `tl.pow_of_int(base, exp)`:
// - it takes two inputs;
// - it is marked pure (kPure);
// - it is printed as `pow_of_int` in TVMScript;
// - on CUDA it is lowered by pow_of_int_op into an extern call of
//   `tl::pow_of_int<exp>`.
TVM_REGISTER_OP("tl.pow_of_int")
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
37
38
39

} // namespace tl
} // namespace tvm