/*!
 * \file tl/op/math.cc
 * \brief Math intrinsics (tl.pow_of_int, tl.infinity) and their CUDA/HIP
 *        lowering rules.
 */

#include <limits>
#include <string>

#include <tvm/ffi/function.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../support/ffi_aliases.h"

namespace tvm {
namespace tl {
using namespace tir;

18
PrimExpr pow_of_int_op(PrimExpr args) {
19
20
21
22
23
24
  const CallNode *call = args.as<CallNode>();
  CHECK(call != nullptr);
  const Array<PrimExpr> &arg = call->args;
  ICHECK_EQ(arg.size(), 2);
  PrimExpr base = arg[0];
  PrimExpr exp = arg[1];
25
26
  String pow_of_int_name =
      "tl::pow_of_int<" + std::to_string(exp.as<IntImmNode>()->value) + ">";
27
  return tir::Call(base.dtype(), tir::builtin::call_extern(),
28
                   {StringImm(pow_of_int_name), base});
29
30
}

31
TVM_REGISTER_OP("tl.pow_of_int")
32
33
34
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
35
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
36
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
37
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
PrimExpr infinity_op(PrimExpr args) {
  const CallNode *call = args.as<CallNode>();
  CHECK(call != nullptr);
  const DataType &dtype = call->dtype;
  ICHECK_EQ(dtype.lanes(), 1);

  // NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype
  if (dtype.is_float()) {
    if (dtype.bits() == 64 || dtype.bits() == 32 || dtype.bits() == 16) {
      return FloatImm(dtype, std::numeric_limits<float>::infinity(),
                      call->span);
    }
  } else if (dtype.is_bfloat16()) {
    return FloatImm(dtype, std::numeric_limits<float>::infinity(), call->span);
  }
  LOG(FATAL) << "Cannot decide infinity for type " << dtype;
  throw; // Unreachable, keeps compiler happy
}

TVM_REGISTER_OP("tl.infinity")
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure))
    .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
63
64
    .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
    .set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
65

66
67
} // namespace tl
} // namespace tvm