loop_fusion_utils.h 6.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_fusion_utils.h
 * \brief Utilities for fusing chains of nested parallel loops in TL transforms
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
38
#include "arith/ir_mutator_with_analyzer.h"
39
40
41
42
43
44
45
46

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

class FragmentAccessDetector : public StmtExprVisitor {
47
public:
48
49
50
51
52
53
  FragmentAccessDetector() = default;

  void Collect(Stmt stmt) { VisitStmt(stmt); }

  bool HasFragmentAccess() { return has_fragment_access_; }

54
55
private:
  void VisitExpr_(const BufferLoadNode *op) final {
56
57
58
59
60
61
62
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

63
  void VisitStmt_(const BufferStoreNode *op) final {
64
65
66
67
68
69
70
71
    // Check if the buffer is in global scope
    if (IsFragmentBuffer(op->buffer)) {
      has_fragment_access_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // Helper function to determine if a buffer is local.fragment
72
73
74
  bool IsFragmentBuffer(const Buffer &buffer) {
    // The storage scope is often encoded in the buffer->data var name or
    // associated attributes.
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    String scope = buffer.scope();
    return scope == "local.fragment";
  }

  bool has_fragment_access_{false};
};

/*!
 * \brief ParallelLoopFuser
 * This class is used to fuse a chain of parallel loops into one loop.
 * The loops must:
 *  - All be parallel (ForKind::kParallel)
 *  - Have bounds from 0 to their extent
 * Once fused, a single loop variable will replace the chain, and the
 * original loop variables will be derived by division and modulo operations.
 *
91
92
 * This can be helpful for inferring layout for the fragment in a subsequent
 * pass.
93
94
 */
class ParallelLoopFuser : public IRMutatorWithAnalyzer {
95
public:
96
97
98
99
100
101
  static Stmt Fuse(Stmt stmt) {
    arith::Analyzer analyzer;
    ParallelLoopFuser substituter(&analyzer);
    return substituter.VisitStmt(stmt);
  }

102
103
104
private:
  ParallelLoopFuser(arith::Analyzer *analyzer)
      : IRMutatorWithAnalyzer(analyzer){};
105

106
  Stmt VisitStmt_(const ForNode *op) final {
107
    // Gather consecutive parallel loops
108
109
    std::vector<const ForNode *> loop_chain;
    const ForNode *current = op;
110
111
112
113
114
115
116
117
118
    // check if has fragment access
    FragmentAccessDetector detector;
    detector.Collect(op->body);
    // Do not fuse if there is a fragment access
    if (detector.HasFragmentAccess()) {
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    }

    while (true) {
119
120
121
122
      if (current->kind != ForKind::kParallel)
        break;
      if (!is_zero(current->min))
        break;
123
124
      loop_chain.push_back(current);

125
      const ForNode *inner_for = current->body.as<ForNode>();
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
      if (!inner_for) {
        break;
      }
      current = inner_for;
    }

    // If only one loop found or loop chain size is 1, no fusion needed.
    if (loop_chain.size() <= 1) {
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    }

    // At this point we have multiple nested parallel loops starting at zero
    // We will fuse them all.
    PrimExpr fused_extent = make_const(DataType::Int(32), 1);
    for (auto it = loop_chain.rbegin(); it != loop_chain.rend(); ++it) {
      fused_extent = fused_extent * (*it)->extent;
    }

    std::string fused_name;
    for (auto it = loop_chain.begin(); it != loop_chain.end(); ++it) {
      fused_name += (*it)->loop_var->name_hint + "_";
    }

    fused_name += "fused";

    // Create a new fused loop var
    Var fused_var(fused_name, DataType::Int(32));

    // The body of the last loop in the chain:
155
    const ForNode *innermost_loop = loop_chain.back();
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
    Stmt body = innermost_loop->body;

    // We need to substitute all loop variables in the chain.
    // The scheme:
    // Suppose we have loops (i in [0,M], j in [0,N], k in [0,O])
    // fused loop var f in [0, M*N*O]
    // i = f / (N*O)
    // j = (f % (N*O)) / O
    // k = f % O
    //
    // Generalizing for a chain of lengths L:
    // extents: E_0, E_1, ... E_{L-1}
    // index_i = (f / (E_{i+1}*...*E_{L-1})) % E_i
    // For the last one, it's just f % E_{L-1} if i == L-1.

    // Compute the "stride" products for each loop variable
    // stride[i] = product of extents of loops after i
    // for L loops: stride[L-1] = 1
    // stride[L-2] = E_{L-1}
    // stride[L-3] = E_{L-1} * E_{L-2}
    // ...
    std::vector<PrimExpr> extents;
    extents.reserve(loop_chain.size());
    for (auto l : loop_chain) {
      extents.push_back(l->extent);
    }

183
184
    std::vector<PrimExpr> strides(loop_chain.size(),
                                  make_const(DataType::Int(32), 1));
185
186
187
188
189
190
191
192
193
194
195
196
197
    for (int i = static_cast<int>(loop_chain.size()) - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * extents[i + 1];
    }

    // We'll create a substitution map for all loop variables
    // index_i = (f / strides[i]) % extents[i]
    // We'll define a helper lambda:
    auto create_index_expr = [&](int i) {
      return FloorMod(FloorDiv(fused_var, strides[i]), extents[i]);
    };

    Map<Var, PrimExpr> var_map;
    for (size_t i = 0; i < loop_chain.size(); i++) {
198
199
200
      const ForNode *loop = loop_chain[i];
      var_map.Set(loop->loop_var,
                  analyzer_->Simplify(create_index_expr(static_cast<int>(i))));
201
202
203
204
205
206
207
208
209
210
211
212
    }

    // Perform the substitution
    body = Substitute(body, var_map);

    // Create the fused loop
    For fused_for = For(fused_var, 0, fused_extent, ForKind::kParallel, body);

    return fused_for;
  }
};

213
214
} // namespace tl
} // namespace tvm