/*!
 * \file tl/op/fill.cc
 *
 * Define element-wise operators.
 */

#include "fill.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/**
 * @brief Construct a Fill operator node from call arguments.
 *
 * Builds a FillNode describing an element-wise fill of a destination buffer
 * region with a scalar/vector value and stores it in `data_`.
 *
 * Detailed behavior:
 * - `args[0]` is normalized into a BufferRegion via NormalizeToBufferRegion:
 *   - If it is a `BufferLoad`, the loaded buffer becomes the destination and
 *     the load indices are converted to per-dimension ranges:
 *     `Ramp(base, stride, lanes)` becomes `Range(base, lanes)` (only
 *     stride == 1 and constant `lanes` are supported); non-ramp indices
 *     become `Range(index, 1)`.
 *   - Otherwise it is treated as an access pointer; the destination buffer is
 *     resolved from the pointer's variable and the region covers the full
 *     buffer shape in every dimension.
 * - `args[1]` is the fill value; it is cast to the destination buffer's dtype
 *   if necessary.
 * - Validation:
 *   - The region's dimensionality must match the destination rank.
 *   - For statically known region mins and extents, mins must be >= 0 and
 *     extents must not exceed the corresponding destination shape extents.
 *
 * @param args Call arguments: [dst_access_or_bufferload, value].
 *             - args[0]: destination access (BufferLoad or access pointer).
 *             - args[1]: value to fill (scalar or vector).
 *
 * @note Unsupported or out-of-bounds inputs abort via CHECK/ICHECK (e.g.,
 *       non-unit-stride ramps or non-constant lanes).
 */
Fill::Fill(Array<PrimExpr> args) {
  ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

  BufferRegion region = NormalizeToBufferRegion(args[0]);
  node->dst = region->buffer;
  node->region = region->region;

  if (args[1]->dtype != node->dst->dtype) {
    node->value = Cast(node->dst->dtype, args[1]);
  } else {
    node->value = args[1];
  }

  ICHECK(node->region.size() == node->dst->shape.size())
      << "region size = " << node->region.size()
      << " != dst rank = " << node->dst->shape.size();
  for (size_t i = 0; i < node->region.size(); i++) {
    // Bounds-check the region where its min/extent are statically known
    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
      int64_t min = min_imm->value;
      ICHECK_GE(min, 0) << "region[" << i << "].min = " << min << " < 0";
    }
    if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
      // Only perform the upper-bound check when the destination shape
      // extent is also statically known. If the shape is symbolic (e.g., Var),
      // skip this static check to avoid invalid downcasts.
      if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
        ICHECK_LE(extent_imm->value, shape_imm->value)
            << "region[" << i << "].extent = " << extent_imm->value << " > "
            << node->dst->shape[i];
      }
    }
  }
  data_ = std::move(node);
}
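
// Illustrative sketch (assumed usage, not part of this translation unit):
// constructing a Fill that zeroes a 2-D fp32 buffer. `buf` stands in for a
// hypothetical 64x64 tir::Buffer; args[0] is a BufferLoad whose unit-stride
// Ramp indices describe the filled region, and args[1] is the fill value,
// which the constructor casts to buf->dtype when the dtypes differ.
//
//   PrimExpr dst = BufferLoad(buf, {Ramp(0, 1, 64), Ramp(0, 1, 64)});
//   Fill fill({dst, FloatImm(DataType::Float(32), 0.0)});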

/**
 * @brief Create a copy of this FillNode and return it as a TileOperator.
 *
 * Constructs a new FillNode by copying the current node and wraps the copy in a
 * Fill TileOperator.
 *
 * @return TileOperator A TileOperator that owns the copied FillNode.
 */
TileOperator FillNode::Clone() const {
  auto op = tvm::ffi::make_object<FillNode>(*this);
  return Fill(op);
}

/**
 * @brief Build a SIMT-style nested parallel loop that fills the destination
 * buffer.
 *
 * Constructs per-dimension data-parallel loop iterators matching this node's
 * region extents, emits a BufferStore that writes the node's `value` into
 * `dst` at the region-offset indices (loop index plus region min), and nests
 * the loops from innermost to outermost as parallel `For` nodes.
 *
 * @return For The outermost parallel `For` loop of the generated nested SIMT
 *         fill.
 */
For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
    loop_vars.push_back({region[i], var, IterVarType::kDataPar});
    // Offset the loop induction variable by region min to honor sliced regions
    dst_indices.push_back(region[i]->min + var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}
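
// For a 2-D region [min0, min0 + e0) x [min1, min1 + e1), the nest built
// above looks like the following TIR sketch (loop vars are named 'i', 'j',
// ... and are offset by the region mins at the store site):
//
//   for (i, 0, e0) "parallel":
//     for (j, 0, e1) "parallel":
//       dst[min0 + i, min1 + j] = value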

/**
 * @brief Lower this Fill operator to a TIR statement for the target.
 *
 * Lowers the FillNode into a Stmt according to the destination buffer scope:
 * - "local.fragment", "shared", "shared.dyn", and "global": create a parallel
 *   operation from a SIMT loop, infer its layout, partition the root loop by
 *   the thread variable, vectorize the resulting thread loop, and, if a
 *   per-thread predicate exists, guard the vectorized loop with that
 *   predicate.
 * - "local": build a SIMT loop and return its vectorized form.
 * - any other scope: fatal error.
 *
 * The lowering may query layout and thread information from @p T and uses the
 * provided analyzer for any required arithmetic/layout analysis.
 *
 * @param T Lowering arguments (target, thread bounds, thread var, layout map).
 * @param analyzer Arithmetic analyzer used during partitioning and
 *        vectorization.
 * @return Stmt The lowered TIR statement implementing the fill.
 */
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  if (dst.scope() == "local.fragment") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
    return vectorized_thread_loop;
  } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
             dst.scope() == "global") {
    auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
                         false, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
    return Stmt();
  }
}
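
// Rough shape of the lowered code (illustrative, assuming a 1-D fill of 1024
// shared-memory elements by 128 threads): PartitionLoop assigns each thread a
// contiguous chunk indexed by the thread variable (tx below), and
// VectorizeLoop widens the per-thread loop into vector stores:
//
//   for (v, 0, 8) "vectorized":
//     dst[tx * 8 + v] = value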

/**
 * @brief Infer memory/layout mapping for the Fill operator.
 *
 * Fill currently performs no layout inference of its own, so this always
 * returns an empty LayoutMap.
 *
 * @param T Context required for layout inference (unused).
 * @param level The inference level requested (unused).
 * @return LayoutMap Empty map indicating no inferred layouts for this
 *         operator.
 */
LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
                                InferLevel level) const {
  return {};
}

TIR_REGISTER_TL_TILE_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
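
// The registration above exposes the op (presumably as "tl.fill") with two
// inputs: the destination access and the fill value. In TileLang frontend
// scripts this is the operator behind fills such as `T.fill(dst, 0)`
// (frontend spelling assumed here for orientation).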

TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm