/*!
 * \file tl/op/op.cc
 *
 * Defines operators used in the tile library.
 */

#include "op.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

namespace tvm {
namespace tl {

using namespace tir;

// Register RegionOp under the op name `region`.
//  - set_num_inputs(-1): the op takes a variable number of arguments (the
//    RegionOp constructor decodes them as: buffer load, access mask, extents).
//  - TCallEffectKind = kPure: calls to this op are declared side-effect free.
TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

/*!
 * \brief Build the tile-library operator for the callee of `call`.
 *
 * Looks up the "TLOpBuilder" attribute registered on the call's target Op and,
 * if present, invokes it with the call arguments and the buffer map.
 *
 * \param call The call expression to inspect.
 * \param vmap Buffer map forwarded unchanged to the builder.
 * \return An owning pointer to the built operator, or nullptr when the callee
 *         is not an Op or has no "TLOpBuilder" attribute.
 */
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  // The callee is not guaranteed to be an Op. Such calls cannot be
  // tile-library operators, so report "not an operator" instead of aborting
  // on an empty Optional.
  auto maybe_op = call->op.as<Op>();
  if (!maybe_op) {
    return nullptr;
  }
  Op op = maybe_op.value();
  if (!op_map.count(op)) {
    return nullptr;
  }
  // The builder returns an untyped pointer. Wrapping it in a unique_ptr below
  // means this function takes ownership of the result.
  Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
  ICHECK(ptr != nullptr);
  return std::unique_ptr<Operator>(ptr);
}

/*!
 * \brief Statement-level entry point for ParseOperator.
 *
 * Only an `Evaluate` statement whose value is a call can describe an
 * operator. Any other statement yields nullptr.
 *
 * \param stmt The statement to inspect.
 * \param vmap Buffer map forwarded to the call-level ParseOperator.
 * \return The parsed operator, or nullptr.
 */
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
  const EvaluateNode *evaluate_node = stmt.as<EvaluateNode>();
  if (evaluate_node == nullptr) {
    return nullptr;
  }
  const CallNode *call_node = evaluate_node->value.as<CallNode>();
  if (call_node == nullptr) {
    return nullptr;
  }
  return ParseOperator(GetRef<Call>(call_node), vmap);
}

42
Var GetVarFromAccessPtr(const PrimExpr &expr) {
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  auto call = expr.as<CallNode>();
  ICHECK(call);
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
  auto var = call->args[1].as<VarNode>();
  ICHECK(var);
  return GetRef<Var>(var);
}

/*!
 * \brief Construct a RegionOp from the arguments of a `region` call.
 *
 * Argument layout, as decoded below:
 *   args[0]    a BufferLoad; its indices give the per-dimension minimum,
 *   args[1]    a constant integer access mask,
 *   args[2..]  one extent per BufferLoad index.
 *
 * \param args The call arguments.
 * \param vmap Buffer map (not used by this constructor).
 */
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
  size_t n = args.size();
  // The BufferLoad and the access mask are mandatory. Without this check,
  // `n - 2` would wrap around to a huge size_t.
  ICHECK(n >= 2) << "region expects at least 2 arguments, got " << n;
  size_t ndim = n - 2;
  auto load = args[0].as<BufferLoadNode>();
  ICHECK(load);
  ICHECK(load->indices.size() == ndim);
  buffer_ = load->buffer;
  // as_const_int returns nullptr for non-constant expressions. Fail with a
  // clear message instead of dereferencing it.
  const int64_t *mask = as_const_int(args[1]);
  ICHECK(mask) << "region access mask must be a constant integer, got "
               << args[1];
  access_mask_ = static_cast<int>(*mask);
  for (size_t i = 0; i < ndim; i++) {
    PrimExpr min = load->indices[i];
    PrimExpr extent = args[2 + i];
    ranges_.push_back(Range::FromMinExtent(min, extent));
  }
}

bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
68
69
70
71
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
72
73
74
75
  }
  return true;
}

76
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
77
78
79
80
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

81
82
83
84
Stmt Operator::Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const {
  return {};
}
85

86
87
88
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}
89

90
91
} // namespace tl
} // namespace tvm