/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_blocal_layout_transform.cc
 * \brief Transform B_local layout from shared memory thread-interleaved layout
 *        to local row-major layout using ds_read_vector hardware instructions.
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
45
46
47
48
49
 * \brief Transformer that handles B_local layout transformation with loop optimization
 *
 * This transformer handles two cases:
 * 1. B_local store with outer loop: halve the loop extent and double the offset
 * 2. B_local store without outer loop: just double the offset
wangziyang's avatar
wangziyang committed
50
 */
51
52
53
54
55
56
57
58
59
60
61
62
class BLocalLayoutTransformer : public StmtExprMutator {
 public:
  explicit BLocalLayoutTransformer(int expand)
      : expand_(expand) {}

 private:
  int expand_;

  Stmt VisitStmt_(const ForNode* op) final {
    // 只处理 serial 外层循环
    if (op->kind != ForKind::kSerial) {
      return StmtExprMutator::VisitStmt_(op);
wangziyang's avatar
wangziyang committed
63
    }
64
65
66
67
68

    // 判断是否是 B_local 写循环
    auto store = op->body.as<BufferStoreNode>();
    if (!store) {
      return StmtExprMutator::VisitStmt_(op);
wangziyang's avatar
wangziyang committed
69
70
    }

71
72
73
    if (!IsBLocal(store->buffer)) {
      return StmtExprMutator::VisitStmt_(op);
    }
wangziyang's avatar
wangziyang committed
74

75
76
77
78
79
80
81
82
83
84
85
86
87
88
    int64_t old_extent = op->extent.as<IntImmNode>()->value;

    ICHECK(old_extent % expand_ == 0)
        << "Loop extent must be divisible by expand factor.";

    int64_t new_extent = old_extent / expand_;

    // 修改循环范围
    For new_for =
        For(op->loop_var,
            op->min,
            Integer(new_extent),
            op->kind,
            MutateStore(store, op->loop_var));
wangziyang's avatar
wangziyang committed
89

90
    return new_for;
wangziyang's avatar
wangziyang committed
91
92
  }

93
94
95
  bool IsBLocal(const Buffer& buffer) {
    std::string name = buffer->name;
    return name.find("B_local") != std::string::npos;
wangziyang's avatar
wangziyang committed
96
97
  }

98
99
  Stmt MutateStore(const BufferStoreNode* store,
                   const Var& loop_var) {
wangziyang's avatar
wangziyang committed
100

101
    Array<PrimExpr> new_indices = store->indices;
wangziyang's avatar
wangziyang committed
102

103
    PrimExpr new_value = store->value;
wangziyang's avatar
wangziyang committed
104

105
106
107
    // 修改切片跨度:
    // 原来 j*vec : j*vec+vec
    // 改为 j*vec : j*vec*expand + vec
wangziyang's avatar
wangziyang committed
108

109
110
111
    PrimExpr idx = store->indices[0];
    //T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
    std::cout << idx << std::endl;
wangziyang's avatar
wangziyang committed
112

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
    // 解析 j*vec 结构
    // 假设结构为 j * vec + const

    // 不改 RHS
    // PrimExpr value = store->value;

    // 修改写入向量宽度
    // 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
    // 匹配 j * stride
    // Ramp(base=j*8, stride=1, lanes=8)
    if (const auto* ramp = idx.as<RampNode>()) {

      PrimExpr base = ramp->base;
      PrimExpr stride = ramp->stride;
      int old_lanes = ramp->lanes.as<IntImmNode>()->value;
      int new_lanes = old_lanes * expand_;

      // 匹配 base = j * stride_val
      if (const auto* mul = base.as<MulNode>()) {

        if (mul->a.same_as(loop_var)) {
          int64_t old_stride =
              mul->b.as<IntImmNode>()->value;

          int64_t new_stride =
              old_stride * expand_;

          PrimExpr new_base =
              loop_var *
              make_const(DataType::Int(32), new_stride);

          new_indices.Set(
              0,
              Ramp(new_base, stride, new_lanes));
        }
        else if (mul->b.same_as(loop_var)) {
          int64_t old_stride =
              mul->a.as<IntImmNode>()->value;

          int64_t new_stride =
              old_stride * expand_;

          PrimExpr new_base =
              make_const(DataType::Int(32), new_stride) *
              loop_var;

          new_indices.Set(
              0,
              Ramp(new_base, stride, new_lanes));
        }
wangziyang's avatar
wangziyang committed
163
      }
164
    }
wangziyang's avatar
wangziyang committed
165

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    if (auto* load = new_value.as<BufferLoadNode>()) {
      // BufferLoad with region access: B_shared[start : end]
      // end - start = lanes,需要同步扩展
      Array<PrimExpr> value_indices = load->indices;
      if (auto* old_ramp = load->indices[0].as<RampNode>()) {

          PrimExpr scalar_base = old_ramp->base;     // 必须是 scalar
          PrimExpr stride      = old_ramp->stride;
          //RHS 4 lane
          int old_lanes = old_ramp->lanes.as<IntImmNode>()->value;
          //RHS 8 lane
          int new_lanes = old_lanes * expand_;

          value_indices.Set(
              0,
              Ramp(scalar_base, stride, new_lanes)
          );

          new_value = BufferLoad(load->buffer, value_indices);
      }
wangziyang's avatar
wangziyang committed
186
187
    }

188
189
190
    return BufferStore(store->buffer,
                       new_value,
                       new_indices);
wangziyang's avatar
wangziyang committed
191
  }
192
};
wangziyang's avatar
wangziyang committed
193
194


195
196
197
198
Stmt InjectBLocalLayoutTransformPass(Stmt stmt, int expand) {
  return BLocalLayoutTransformer(expand)(std::move(stmt));
}

wangziyang's avatar
wangziyang committed
199
200
201
202
203
204
205
206
207
208
209
210

using namespace tir::transform;

tvm::transform::Pass InjectBLocalLayoutTransform() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets
    if (!IsDCUTarget(m)) {
      std::cout << "[DEBUG InjectBLocalLayoutTransform] Not a DCU target, skipping" << std::endl;
      return f;
    }

    auto* n = f.CopyOnWrite();
211
    n->body = InjectBLocalLayoutTransformPass(n->body, 2);
wangziyang's avatar
wangziyang committed
212
213
214
215
216
217
218
219
220
221
222
223
224
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
}

// Register the pass constructor with TVM's FFI so it is reachable from
// Python as tl.transform.InjectBLocalLayoutTransform.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransform",
                        InjectBLocalLayoutTransform);
}

}  // namespace tl
}  // namespace tvm