/*!
 * \file tl/op/elem.cc
 *
 * Define elment-wise operators.
 */

#include "elem.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Construct a Fill operator from TL call arguments.
 *
 * args[0] is either a BufferLoad (fill an explicit element/region, possibly
 * vectorized via Ramp indices) or an access pointer (fill the whole buffer
 * looked up through \p vmap). args[1] is the fill value; it is cast to the
 * destination dtype when they differ.
 */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<FillNode> node = make_object<FillNode>();

  if (args[0]->IsInstance<BufferLoadNode>()) {
    // BufferLoad form: convert each index into a Range of the fill region.
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Check the stride pointer before dereferencing: a non-constant
        // stride previously crashed instead of reporting the error.
        const auto *stride = ramp->stride.as<IntImmNode>();
        CHECK(stride && stride->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        // Scalar index -> single-element range.
        node->region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    node->dst = buffer_load->buffer;
  } else {
    // Access-pointer form: the region covers the whole destination buffer.
    node->dst = vmap[GetVarFromAccessPtr(args[0])];
    for (size_t i = 0; i < node->dst->shape.size(); i++) {
      node->region.push_back(Range(0, node->dst->shape[i]));
    }
  }

  // Coerce the fill value to the destination dtype if needed.
  if (args[1]->dtype != node->dst->dtype) {
    node->value = Cast(node->dst->dtype, args[1]);
  } else {
    node->value = args[1];
  }

  ICHECK(node->region.size() == node->dst->shape.size())
      << "region size = " << node->region.size()
      << " != " << node->dst->shape.size();
  // Static bound checks only; dynamic mins/extents/shapes are not checked
  // here and are left to later passes / runtime.
  for (size_t i = 0; i < node->region.size(); i++) {
    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
      ICHECK_GE(min_imm->value, 0)
          << "region[" << i << "] = " << min_imm->value << " < 0";
    }
    const auto *extent_imm = node->region[i]->extent.as<IntImmNode>();
    const auto *shape_imm = node->dst->shape[i].as<IntImmNode>();
    // Compare only when both sides are static: the old code Downcast-ed
    // shape[i] unconditionally, which aborts on dynamic shapes.
    if (extent_imm && shape_imm) {
      ICHECK_LE(extent_imm->value, shape_imm->value)
          << "region[" << i << "] = " << extent_imm->value << " > "
          << node->dst->shape[i];
    }
  }
  data_ = std::move(node);
}

TileOperator FillNode::Clone() const {
  // Deep-copy this node and wrap the copy in a fresh Fill handle.
  return Fill(make_object<FillNode>(*this));
}

// Build an ndim-deep parallel loop nest that stores `value` into `dst`.
// Each loop variable is named 'i', 'j', 'k', ... and runs from 0 to the
// corresponding region extent; the innermost statement is the BufferStore.
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  const int ndim = static_cast<int>(dst->shape.size());
  Array<IterVar> iter_vars;
  Array<PrimExpr> store_indices;
  for (int d = 0; d < ndim; d++) {
    Var v(std::string{char('i' + d)}, region[d]->extent->dtype);
    iter_vars.push_back(IterVar(region[d], v, IterVarType::kDataPar));
    store_indices.push_back(v);
  }
  // Wrap the store in For loops from the innermost dimension outward.
  Stmt nest = BufferStore(dst, value, store_indices);
  for (int d = ndim - 1; d >= 0; d--) {
    nest = For(iter_vars[d]->var, 0, iter_vars[d]->dom->extent,
               ForKind::kParallel, nest);
  }
  return Downcast<For>(nest);
}

// Lower the fill into a thread-partitioned, vectorized loop nest.
// Supported destination scopes: "local" (plain registers), "local.fragment",
// "shared" and "shared.dyn"; anything else is a fatal error.
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto scope = dst.scope();

  if (scope == "local") {
    // Plain per-thread storage: no cross-thread distribution is needed,
    // just vectorize the SIMT loop nest directly.
    return VectorizeLoop(MakeSIMTLoop(analyzer));
  }

  if (scope == "local.fragment" || scope == "shared.dyn" || scope == "shared") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    LayoutInferArgs infer_args{T.target, T.thread_bounds, T.layout_map};
    par_op->InferLayout(infer_args, InferLevel::kFree);
    if (scope == "local.fragment") {
      // NOTE(review): fragment scope historically ran a second kFree
      // inference pass with identical arguments; kept to preserve the
      // original lowering sequence exactly.
      par_op->InferLayout(infer_args, InferLevel::kFree);
    }
    // Distribute the loop across threads, then vectorize what remains.
    auto partitioned = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized = VectorizeLoop(partitioned);
    // Guard with the thread predicate when partitioning produced one.
    auto predicate = par_op->GetPredicate(T.thread_var);
    if (predicate.defined()) {
      return IfThenElse(predicate.value(), vectorized);
    }
    return vectorized;
  }

  LOG(FATAL) << "Unsupported scope " << dst.scope();
}

// Fill places no constraints on buffer layouts: it returns an empty map and
// defers layout decisions to the ParallelOp created during lowering.
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  return LayoutMap();
}

// Register `fill` in the TL op registry: two inputs (destination access
// pointer / BufferLoad, and the fill value). Marked opaque because it
// writes memory and must not be reordered or eliminated.
TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm