loop_fusion_utils.h 6.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_fusion_utils.h
 * \brief Utilities for fusing nested parallel loops in TL transforms
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
36
#include "arith/ir_mutator_with_analyzer.h"
37
38
39
40
41
42
43
44

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

class FragmentAccessDetector : public StmtExprVisitor {
45
public:
46
47
  FragmentAccessDetector() = default;

48
  void Collect(const Stmt &stmt) { VisitStmt(stmt); }
49
50
51

  bool HasFragmentAccess() { return has_fragment_access_; }

52
53
private:
  void VisitExpr_(const BufferLoadNode *op) final {
54
55
56
57
58
59
60
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

61
  void VisitStmt_(const BufferStoreNode *op) final {
62
63
64
65
66
67
68
69
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // Helper function to determine if a buffer is local.fragment
70
71
72
  bool IsFragmentBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes.
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
    String scope = buffer.scope();
    return scope == "local.fragment";
  }

  bool has_fragment_access_{false};
};

/*!
 * \brief ParallelLoopFuser
 * This class is used to fuse a chain of parallel loops into one loop.
 * The loops must:
 *  - All be parallel (ForKind::kParallel)
 *  - Have bounds from 0 to their extent
 * Once fused, a single loop variable will replace the chain, and the
 * original loop variables will be derived by division and modulo operations.
 *
89
90
 * This can be helpful for inferring layout for the fragment in a subsequent
 * pass.
91
92
 */
class ParallelLoopFuser : public IRMutatorWithAnalyzer {
93
public:
94
  static Stmt Fuse(const Stmt &stmt) {
95
96
97
98
99
    arith::Analyzer analyzer;
    ParallelLoopFuser substituter(&analyzer);
    return substituter.VisitStmt(stmt);
  }

100
101
102
private:
  ParallelLoopFuser(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer){};
103

104
  Stmt VisitStmt_(const ForNode *op) final {
105
    // Gather consecutive parallel loops
106
107
    std::vector<const ForNode *> loop_chain;
    const ForNode *current = op;
108
109
110
111
112
113
114
115
    // check if has fragment access
    FragmentAccessDetector detector;
    detector.Collect(op->body);
    // Do not fuse if there is a fragment access
    if (detector.HasFragmentAccess()) {
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    }
    while (true) {
116
117
118
119
      if (current->kind != ForKind::kParallel)
        break;
      if (!is_zero(current->min))
        break;
120
121
      loop_chain.push_back(current);

122
      const ForNode *inner_for = current->body.as<ForNode>();
123
124
125
126
127
128
129
130
131
132
133
      if (!inner_for) {
        break;
      }
      current = inner_for;
    }

    // If only one loop found or loop chain size is 1, no fusion needed.
    if (loop_chain.size() <= 1) {
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    }

134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    // If one of the loop has extent which is not 2^n, we do not fuse
    for (auto l : loop_chain) {
      PrimExpr extent = l->extent;
      // If extent is not a constant integer, we cannot determine if it's power
      // of 2
      if (!extent.as<IntImmNode>()) {
        return IRMutatorWithAnalyzer::VisitStmt_(op);
      }
      int64_t value = extent.as<IntImmNode>()->value;
      // Check if value is power of 2: value > 0 and only has one bit set
      if (value <= 0 || (value & (value - 1)) != 0) {
        return IRMutatorWithAnalyzer::VisitStmt_(op);
      }
    }

149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
    // At this point we have multiple nested parallel loops starting at zero
    // We will fuse them all.
    PrimExpr fused_extent = make_const(DataType::Int(32), 1);
    for (auto it = loop_chain.rbegin(); it != loop_chain.rend(); ++it) {
      fused_extent = fused_extent * (*it)->extent;
    }

    std::string fused_name;
    for (auto it = loop_chain.begin(); it != loop_chain.end(); ++it) {
      fused_name += (*it)->loop_var->name_hint + "_";
    }

    fused_name += "fused";

    // Create a new fused loop var
    Var fused_var(fused_name, DataType::Int(32));

    // The body of the last loop in the chain:
167
    const ForNode *innermost_loop = loop_chain.back();
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
    Stmt body = innermost_loop->body;

    // We need to substitute all loop variables in the chain.
    // The scheme:
    // Suppose we have loops (i in [0,M], j in [0,N], k in [0,O])
    // fused loop var f in [0, M*N*O]
    // i = f / (N*O)
    // j = (f % (N*O)) / O
    // k = f % O
    //
    // Generalizing for a chain of lengths L:
    // extents: E_0, E_1, ... E_{L-1}
    // index_i = (f / (E_{i+1}*...*E_{L-1})) % E_i
    // For the last one, it's just f % E_{L-1} if i == L-1.

    // Compute the "stride" products for each loop variable
    // stride[i] = product of extents of loops after i
    // for L loops: stride[L-1] = 1
    // stride[L-2] = E_{L-1}
    // stride[L-3] = E_{L-1} * E_{L-2}
    // ...
    std::vector<PrimExpr> extents;
    extents.reserve(loop_chain.size());
    for (auto l : loop_chain) {
      extents.push_back(l->extent);
    }

195
196
    std::vector<PrimExpr> strides(loop_chain.size(),
                                  make_const(DataType::Int(32), 1));
197
198
199
200
201
202
203
204
205
206
207
208
209
    for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * extents[i + 1];
    }

    // We'll create a substitution map for all loop variables
    // index_i = (f / strides[i]) % extents[i]
    // We'll define a helper lambda:
    auto create_index_expr = [&](int i) {
      return FloorMod(FloorDiv(fused_var, strides[i]), extents[i]);
    };

    Map<Var, PrimExpr> var_map;
    for (size_t i = 0; i < loop_chain.size(); i++) {
210
211
212
      const ForNode *loop = loop_chain[i];
      var_map.Set(loop->loop_var,
                  analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
213
214
215
216
217
218
219
    }

    // Perform the substitution
    body = Substitute(body, var_map);

    // Create the fused loop
    For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);
220
    fused_for.CopyOnWrite()->annotations = op->annotations;
221
222
223
224
    return fused_for;
  }
};

225
226
} // namespace tl
} // namespace tvm