/*!
 * \file tl/op/fill.cc
 *
 * Define element-wise operators.
 */

#include "fill.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Fill operator node from call arguments and a buffer map.
 *
 * This constructor builds a FillNode describing an element-wise fill of a
 * destination buffer region with a scalar/vector value and stores it in
 * `data_`.
 *
 * Detailed behavior:
 * - If `args[0]` is a `BufferLoad`, the loaded buffer becomes the destination
 * and the load indices are converted to per-dimension ranges:
 *   - `Ramp(base, stride, lanes)` is converted to `Range(base, lanes)`. Only
 * stride == 1 and constant `lanes` are supported.
 *   - Non-ramp indices become `Range(index, 1)`.
 * - Otherwise `args[0]` is treated as an access pointer; the destination buffer
 * is resolved via `vmap[GetVarFromAccessPtr(args[0])]` and the region is the
 * full buffer shape for each dimension.
 * - `args[1]` is used as the fill value; it is cast to the destination buffer's
 * dtype if necessary.
 * - Performs validation:
 *   - Region dimensionality must match destination rank.
 *   - For statically-known region mins and extents, checks that mins >= 0 and
 * extents do not exceed the corresponding destination shape extents.
 *
 * Parameters:
 * @param args Call arguments: expected layout is [dst_access_or_bufferload,
 * value].
 *             - args[0]: destination access (BufferLoad or pointer expression).
 *             - args[1]: value to fill (scalar or vector).
 * @param vmap Mapping from buffer variables to Buffer objects; used to resolve
 * the destination when args[0] is not a BufferLoad.
 *
 * Notes:
 * - The constructor enforces constraints (e.g., stride == 1 ramps, constant
 * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out
 * of bounds.
 */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<FillNode> node = make_object<FillNode>();

  if (args[0]->IsInstance<BufferLoadNode>()) {
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Check the stride is a constant 1 before dereferencing it.
        const auto *stride = ramp->stride.as<IntImmNode>();
        CHECK(stride && stride->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        node->region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    node->dst = buffer_load->buffer;
  } else {
    node->dst = vmap[GetVarFromAccessPtr(args[0])];
    for (int i = 0; i < node->dst->shape.size(); i++) {
      node->region.push_back(Range(0, node->dst->shape[i]));
    }
  }

  if (args[1]->dtype != node->dst->dtype) {
    node->value = Cast(node->dst->dtype, args[1]);
  } else {
    node->value = args[1];
  }

  ICHECK(node->region.size() == node->dst->shape.size())
      << "region size = " << node->region.size()
      << " != " << node->dst->shape.size();
  for (int i = 0; i < node->region.size(); i++) {
    // bound check if region is static
    if (node->region[i]->min.as<IntImm>()) {
      int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
      ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
    }
    if (node->region[i]->extent.as<IntImm>()) {
      int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
      ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
          << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
    }
  }
  data_ = std::move(node);
}

/**
 * @brief Create a copy of this FillNode and return it as a TileOperator.
 *
 * Constructs a new FillNode by copying the current node and wraps the copy in a
 * Fill TileOperator.
 *
 * @return TileOperator A TileOperator that owns the copied FillNode.
 */
TileOperator FillNode::Clone() const {
  auto op = make_object<FillNode>(*this);
  return Fill(op);
}

/**
 * @brief Build a SIMT-style nested parallel loop that fills the destination
 * buffer.
 *
 * Constructs per-dimension data-parallel loop iterators matching this node's
 * region extents, emits a BufferStore that writes the node's `value` into `dst`
 * at the loop indices, and nests the loops (innermost to outermost) as parallel
 * `For` nodes. Returns the outermost `For` loop representing the complete
 * multi-dimensional fill kernel.
 *
 * @return For Outermost parallel `For` loop of the generated nested SIMT loop.
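 *
 * Example (illustrative): for a 2-D destination with region extents (M, N),
 * the generated nest is equivalent to
 * @code
 *   for (i, 0, M) "parallel"
 *     for (j, 0, N) "parallel"
 *       dst[i, j] = value;
 * @endcode
 * Note that each loop runs from 0 to the region extent and the loop
 * variables are used directly as store indices.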
 */
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
    loop_vars.push_back({region[i], var, IterVarType::kDataPar});
    dst_indices.push_back(var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}

/**
 * @brief Lower this Fill operator to a TIR statement for the target.
 *
 * Lowers the FillNode into a Stmt according to the destination buffer scope:
 * - "local.fragment" and shared ("shared", "shared.dyn"): create a parallel
 *   operation from a SIMT loop, infer its layout, partition the root loop by
 *   the thread variable, vectorize the resulting thread loop, and, if a
 *   per-thread predicate exists, guard the vectorized loop with that
 *   predicate.
 * - "local": build a SIMT loop and return its vectorized form.
 * - other scopes: fatal error.
 *
 * The lowering may query layout and thread information from @p T and uses the
 * provided analyzer for any required arithmetic/layout analysis.
 *
 * @param T Lowering arguments (target, thread bounds, thread var, layout map).
 * @return Stmt The lowered TIR statement implementing the fill.
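 *
 * Example (illustrative; names like `elems_per_thread` are placeholders, not
 * identifiers from this file): for a fragment or shared destination the
 * lowered statement has roughly the shape
 * @code
 *   if (predicate(tx))             // only when GetPredicate() is defined
 *     for (i, 0, elems_per_thread)
 *       dst[...] = value;          // inner accesses vectorized where legal
 * @endcode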
 */
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (dst.scope() == "local.fragment") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
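    // "local" is per-thread private storage, so every thread fills its own
    // copy directly; no thread partitioning or predicate is required.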
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop);
    return vectorized_thread_loop;
  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
             dst.scope() == "global") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
  }
}

/**
 * @brief Infer memory/layout mapping for the Fill operator.
 *
 * Returns the layout mapping produced by layout inference for this FillNode.
 * Currently no layout inference is performed for Fill, so this always returns
 * an empty LayoutMap.
 *
 * @param T Context required for layout inference (unused).
 * @param level The inference level requested (unused).
 * @return LayoutMap Empty map indicating no inferred layouts for this operator.
 */
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  return {};
}

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });

} // namespace tl
} // namespace tvm