op.cc 2.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/op.cc
 *
 * Define operators used in the tile library.
 */

#include "op.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

namespace tvm {
namespace tl {

using namespace tir;

// Register the `region` tile-library op, built by RegionOp.
// set_num_inputs(-1): the op takes a variable number of arguments (a
// BufferLoad, an access mask, then one extent per dimension -- see RegionOp).
// TCallEffectKind kPure: the call itself is declared side-effect free.
TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

/*!
 * \brief Build a tile-library operator object from a TIR call.
 *
 * Looks up the "TLOpBuilder" attribute registered for the call's callee and,
 * if one exists, invokes it with the call arguments and \p vmap.
 *
 * \param call The call expression to parse.
 * \param vmap Buffer map forwarded to the builder.
 * \return An owning pointer to the constructed operator, or nullptr when the
 *         callee is not an Op or has no "TLOpBuilder" attribute registered.
 */
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  // A call's callee is not necessarily an Op (e.g. it may be a GlobalVar).
  // Such calls cannot be tile-library operators, so reject them instead of
  // unconditionally unwrapping the Optional (which aborts on a null value).
  auto maybe_op = call->op.as<Op>();
  if (!maybe_op.defined()) {
    return nullptr;
  }
  Op op = maybe_op.value();
  if (!op_map.count(op)) {
    return nullptr;
  }
  // The builder hands back a raw pointer; ownership is taken immediately below.
  Operator* ptr = static_cast<Operator*>(op_map[op](call->args, vmap));
  ICHECK(ptr != nullptr) << "TLOpBuilder for " << op << " returned a null operator";
  return std::unique_ptr<Operator>(ptr);
}

/*!
 * \brief Build a tile-library operator from a statement.
 *
 * Only statements of the form Evaluate(Call(...)) are considered. The wrapped
 * call is forwarded to the Call overload of ParseOperator.
 *
 * \param stmt The statement to inspect.
 * \param vmap Buffer map forwarded to the Call overload.
 * \return The parsed operator, or nullptr if \p stmt is not an Evaluate whose
 *         value is a Call, or if the Call overload returns nullptr.
 */
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
  const auto* evaluate = stmt.as<EvaluateNode>();
  if (evaluate == nullptr) {
    return nullptr;
  }
  const auto* call_node = evaluate->value.as<CallNode>();
  if (call_node == nullptr) {
    return nullptr;
  }
  return ParseOperator(GetRef<Call>(call_node), vmap);
}

/*!
 * \brief Extract the Var operand (args[1]) of a `tvm_access_ptr` builtin call.
 *
 * \param expr Must be a call to builtin::tvm_access_ptr().
 * \return The Var found at args[1] of the call.
 * \note Fails an ICHECK if \p expr is not a tvm_access_ptr call, or if its
 *       args[1] is not a Var.
 */
Var GetVarFromAccessPtr(const PrimExpr& expr) {
  auto call = expr.as<CallNode>();
  ICHECK(call) << "Expected a tvm_access_ptr call, but got " << expr;
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()))
      << "Expected a tvm_access_ptr call, but got a call to " << call->op;
  auto var = call->args[1].as<VarNode>();
  ICHECK(var) << "Expected args[1] of tvm_access_ptr to be a Var, but got " << call->args[1];
  return GetRef<Var>(var);
}

/*!
 * \brief Construct a RegionOp from the arguments of a region call.
 *
 * Argument layout:
 *   args[0]      BufferLoad; its indices give each dimension's start (min)
 *   args[1]      constant integer access mask
 *   args[2 + i]  extent of dimension i, for i in [0, ndim)
 * where ndim = args.size() - 2 must equal the number of BufferLoad indices.
 *
 * \param args Call arguments in the layout above.
 * \param vmap Not used by this constructor.
 */
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
  // Require the BufferLoad and the access mask; otherwise `n - 2` below would
  // wrap around as an unsigned value.
  ICHECK(args.size() >= 2) << "RegionOp expects at least 2 arguments (buffer load, access mask), but got "
                           << args.size();
  size_t n = args.size();
  size_t ndim = n - 2;
  auto load = args[0].as<BufferLoadNode>();
  ICHECK(load) << "RegionOp expects its first argument to be a BufferLoad, but got " << args[0];
  ICHECK(load->indices.size() == ndim) << "RegionOp got " << ndim << " extent(s) for a BufferLoad with "
                                       << load->indices.size() << " index(es)";
  buffer_ = load->buffer;
  // as_const_int yields nullptr for non-constant expressions; dereferencing it
  // unchecked would be undefined behavior.
  const auto* access_mask = as_const_int(args[1]);
  ICHECK(access_mask != nullptr) << "RegionOp expects a constant integer access mask, but got "
                                 << args[1];
  access_mask_ = static_cast<int>(*access_mask);
  for (size_t i = 0; i < ndim; i++) {
    PrimExpr min = load->indices[i];
    PrimExpr extent = args[2 + i];
    ranges_.push_back(Range::FromMinExtent(min, extent));
  }
}

bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min)) return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) return false;
  }
  return true;
}

/*!
 * \brief Default lowering hook for operators without a lowering.
 *
 * Always fails its ICHECK with "Not Implemented Lower method.". Operators that
 * support lowering presumably override this (declaration is in op.h -- confirm).
 * The trailing return is never reached at runtime; it only satisfies the
 * function's return type.
 */
Stmt Operator::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

/*!
 * \brief Default canonicalization hook: produces no replacement statement.
 * \return A default-constructed (undefined) Stmt.
 */
Stmt Operator::Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const {
  return Stmt();
}

/*!
 * \brief Default layout inference: infers no layouts, regardless of level.
 * \return A default-constructed LayoutMap.
 */
LayoutMap Operator::InferLayout(const LayoutInferArgs& T, InferLevel level) {
  return LayoutMap();
}

}  // namespace tl
}  // namespace tvm