/*!
 * \file tl/op/fill.cc
 *
 * Define element-wise operators.
 */

#include "fill.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Fill operator node from call arguments and a buffer map.
 *
 * This constructor builds a FillNode describing an element-wise fill of a
 * destination buffer region with a scalar/vector value and stores it in
 * `data_`.
 *
 * Detailed behavior:
 * - If `args[0]` is a `tl.region` call, the destination buffer and
 *   per-dimension ranges are taken from the resulting RegionOp; a
 *   `tvm_access_ptr` call instead resolves the buffer via `vmap` and covers
 *   the full buffer shape.
 * - If `args[0]` is a `BufferRegion`, its buffer and ranges are used directly.
 * - If `args[0]` is a `BufferLoad`, the loaded buffer becomes the destination
 *   and the load indices are converted to per-dimension ranges:
 *   - `Ramp(base, stride, lanes)` is converted to `Range(base, lanes)`. Only
 *     stride == 1 and constant `lanes` are supported.
 *   - Non-ramp indices become `Range(index, 1)`.
 * - Otherwise `args[0]` is treated as an access pointer; the destination
 *   buffer is resolved via `vmap[GetVarFromAccessPtr(args[0])]` and the region
 *   is the full buffer shape for each dimension.
 * - `args[1]` is used as the fill value; it is cast to the destination
 *   buffer's dtype if necessary.
 * - Performs validation:
 *   - Region dimensionality must match destination rank.
 *   - For statically known region mins and extents, checks that mins >= 0 and
 *     extents do not exceed the corresponding destination shape extents.
 *
 * @param args Call arguments: expected layout is [dst_access, value].
 *             - args[0]: destination access (region call, access pointer,
 *               BufferRegion, or BufferLoad).
 *             - args[1]: value to fill (scalar or vector).
 * @param vmap Mapping from buffer variables to Buffer objects; used to resolve
 *             the destination when args[0] is an access pointer.
 *
 * Notes:
 * - The constructor enforces constraints (e.g., stride == 1 ramps, constant
 *   lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or
 *   out of bounds.
 */
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

  // Case 1: Region descriptor call (tl.region)
  if (const auto *call = args[0].as<CallNode>()) {
    if (call->op.same_as(RegionOp::Get())) {
      auto region = RegionOp(call->args, vmap);
      node->dst = region->GetBuffer();
      node->region = region->GetRanges();
    } else if (call->op.same_as(builtin::tvm_access_ptr())) {
      node->dst = vmap[GetVarFromAccessPtr(args[0])];
      for (int i = 0; i < node->dst->shape.size(); i++) {
        node->region.push_back(Range(0, node->dst->shape[i]));
      }
    } else {
      ICHECK(false) << "Unsupported call op in tl.fill: "
                    << Downcast<Op>(call->op)->name;
    }

    // Case 2: Explicit BufferRegion (legacy path)
  } else if (args[0]->IsInstance<BufferRegionNode>()) {
    auto region = Downcast<BufferRegion>(args[0]);
    node->dst = region->buffer;
    node->region = region->region;

    // Case 3: Vector/scalar region expressed via BufferLoad indices
  } else if (args[0]->IsInstance<BufferLoadNode>()) {
    auto buffer_load = Downcast<BufferLoad>(args[0]);
    for (const auto &index : buffer_load->indices) {
      if (const auto *ramp = index.as<RampNode>()) {
        // Guard against non-constant strides before dereferencing the IntImm.
        const auto *stride = ramp->stride.as<IntImmNode>();
        CHECK(stride && stride->value == 1)
            << "Only stride 1 ramps are supported";
        const auto *lanes = ramp->lanes.as<IntImmNode>();
        CHECK(lanes)
            << "Scalable vectors not supported in BufferRegion conversion";
        node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
      } else {
        node->region.push_back(Range::FromMinExtent(index, 1));
      }
    }
    node->dst = buffer_load->buffer;
    // Case 4: Access pointer, fill the full buffer
  } else {
    node->dst = vmap[GetVarFromAccessPtr(args[0])];
    for (int i = 0; i < node->dst->shape.size(); i++) {
      node->region.push_back(Range(0, node->dst->shape[i]));
    }
  }

  if (args[1]->dtype != node->dst->dtype) {
    node->value = Cast(node->dst->dtype, args[1]);
  } else {
    node->value = args[1];
  }

  ICHECK(node->region.size() == node->dst->shape.size())
      << "region size = " << node->region.size()
      << " != " << node->dst->shape.size();
  for (int i = 0; i < node->region.size(); i++) {
    // bound check if region is static
    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
      int64_t min = min_imm->value;
      ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
    }
    if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
      // Only perform the upper-bound check when the destination shape
      // extent is also statically known. If the shape is symbolic (e.g., Var),
      // skip this static check to avoid invalid downcasts.
      if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
        ICHECK_LE(extent_imm->value, shape_imm->value)
            << "region[" << i << "] = " << extent_imm->value << " > "
            << node->dst->shape[i];
      }
    }
  }
  data_ = std::move(node);
}
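
// Illustrative sketch (not part of the build): assuming a BufferLoad argument
// such as dst[Ramp(2, /*stride=*/1, /*lanes=*/8), j] with dst : float16[16, 32]
// and a float32 fill value, the constructor above would roughly produce:
//   node->dst    = dst
//   node->region = {Range(2, 8), Range(j, 1)}
//   node->value  = Cast(float16, value)
// The names `dst`, `j`, and `value` are hypothetical placeholders.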

/**
 * @brief Create a copy of this FillNode and return it as a TileOperator.
 *
 * Constructs a new FillNode by copying the current node and wraps the copy in a
 * Fill TileOperator.
 *
 * @return TileOperator A TileOperator that owns the copied FillNode.
 */
TileOperator FillNode::Clone() const {
  auto op = tvm::ffi::make_object<FillNode>(*this);
  return Fill(op);
}

/**
 * @brief Build a SIMT-style nested parallel loop that fills the destination
 * buffer.
 *
 * Constructs per-dimension data-parallel loop iterators matching this node's
 * region extents, emits a BufferStore that writes the node's `value` into
 * `dst` at indices offset by each region's min (so sliced regions are
 * honored), and nests the loops (innermost to outermost) as parallel `For`
 * nodes. Returns the outermost `For` loop representing the complete
 * multi-dimensional fill kernel.
 *
 * @return For Outermost parallel `For` loop of the generated nested SIMT loop.
 */
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
    loop_vars.push_back({region[i], var, IterVarType::kDataPar});
    // Offset the loop induction variable by region min to honor sliced regions
    dst_indices.push_back(region[i]->min + var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}
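
// Illustrative sketch (not part of the build): for region = {Range(2, 16),
// Range(0, 32)}, MakeSIMTLoop yields a nest roughly equivalent to this TIR
// pseudocode (loop variables are named 'i', 'j', 'k', ... per dimension):
//   for (i, 0, 16) "parallel":
//     for (j, 0, 32) "parallel":
//       dst[2 + i, 0 + j] = value
// The outermost For is what gets returned to the caller.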

/**
 * @brief Lower this Fill operator to a TIR statement for the target.
 *
 * Lowers the FillNode into a Stmt according to the destination buffer scope:
 * - "local.fragment", shared ("shared", "shared.dyn"), and "global": create a
 *   parallel operation from a SIMT loop, infer its layout, partition the root
 *   loop by the thread variable, vectorize the resulting thread loop, and, if
 *   a per-thread predicate exists, guard the vectorized loop with that
 *   predicate.
 * - "local": build a SIMT loop and return its vectorized form.
 * - other scopes: fatal error.
 *
 * The lowering may query layout and thread information from @p T and uses the
 * provided analyzer for any required arithmetic/layout analysis.
 *
 * @param T Lowering arguments (target, thread bounds, thread var, layout map).
 * @return Stmt The lowered TIR statement implementing the fill.
 */
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (dst.scope() == "local.fragment") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
    return vectorized_thread_loop;
  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
             dst.scope() == "global") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
    return Stmt();
  }
}
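
// Illustrative sketch (not part of the build): for a fragment/shared/global
// destination, the SIMT loop above is partitioned over T.thread_var and then
// vectorized, producing code roughly of the form:
//   if (pred(tx)):                      // only when a predicate is inferred
//     for (v, 0, vec) "vectorized":
//       dst[index(tx, v)] = value
// where `tx` stands for the thread variable and `index(...)` for the mapping
// chosen by PartitionLoop/VectorizeLoop; both are placeholders, not real APIs.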

/**
 * @brief Infer memory/layout mapping for the Fill operator.
 *
 * Returns the layout mapping produced by layout inference for this FillNode.
 * Currently no layout inference is performed for Fill and the function returns
 * an empty LayoutMap.
 *
 * @param T Context required for layout inference (unused).
 * @param level The inference level requested (unused).
 * @return LayoutMap Empty map indicating no inferred layouts for this operator.
 */
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  return {};
}

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
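
// With this registration, a fill appears in TIR as an opaque call taking two
// inputs (destination access and fill value), roughly tl.fill(dst_access,
// value); the exact frontend spelling is defined elsewhere and the form shown
// here is only illustrative.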

TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm