inject_blocal_layout_transform.cc 5.39 KB
Newer Older
wangziyang's avatar
wangziyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_blocal_layout_transform.cc
 * \brief Transform B_local layout from shared memory thread-interleaved layout
 *        to local row-major layout using ds_read_vector hardware instructions.
 */

#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
45
46
47
48
49
 * \brief Transformer that handles B_local layout transformation with loop optimization
 *
 * This transformer handles two cases:
 * 1. B_local store with outer loop: halve the loop extent and double the offset
 * 2. B_local store without outer loop: just double the offset
wangziyang's avatar
wangziyang committed
50
 */
51
52
53
54
55
56
57
58
59
class BLocalLayoutTransformer : public StmtExprMutator {
 public:
  explicit BLocalLayoutTransformer(int expand)
      : expand_(expand) {}

 private:
  int expand_;

  Stmt VisitStmt_(const ForNode* op) final {
qisan's avatar
qisan committed
60
61
62
63
64
65
    // 1. 先递归处理子节点(重要:确保处理了嵌套的 For 或 Attr)
    Stmt new_body = this->VisitStmt(op->body);
    
    // 2. 检查当前循环是否是目标循环
    // 即使 body 变了,我们也尝试看看能不能在这个 loop 层级做变换
    auto store = new_body.as<BufferStoreNode>();
66
67
    if (op->kind != ForKind::kSerial) {
      return StmtExprMutator::VisitStmt_(op);
wangziyang's avatar
wangziyang committed
68
    }
69
70
    if (!store) {
      return StmtExprMutator::VisitStmt_(op);
wangziyang's avatar
wangziyang committed
71
72
    }

73
74
75
    if (!IsBLocal(store->buffer)) {
      return StmtExprMutator::VisitStmt_(op);
    }
wangziyang's avatar
wangziyang committed
76

77
78
79
80
81
82
83
84
85
86
87
88
89
    int64_t old_extent = op->extent.as<IntImmNode>()->value;

    ICHECK(old_extent % expand_ == 0)
        << "Loop extent must be divisible by expand factor.";

    int64_t new_extent = old_extent / expand_;

    For new_for =
        For(op->loop_var,
            op->min,
            Integer(new_extent),
            op->kind,
            MutateStore(store, op->loop_var));
wangziyang's avatar
wangziyang committed
90

91
    return new_for;
wangziyang's avatar
wangziyang committed
92
93
  }

94
95
96
  bool IsBLocal(const Buffer& buffer) {
    std::string name = buffer->name;
    return name.find("B_local") != std::string::npos;
wangziyang's avatar
wangziyang committed
97
  }
qisan's avatar
qisan committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
  PrimExpr UpdateIndexBase(PrimExpr base, const Var& loop_var, int expand) {
      if (const auto* add = base.as<AddNode>()) {
          return UpdateIndexBase(add->a, loop_var, expand) + 
                UpdateIndexBase(add->b, loop_var, expand);
      } else if (const auto* mul = base.as<MulNode>()) {
          if (mul->a.same_as(loop_var)) {
              return mul->a * (mul->b * expand);
          } else if (mul->b.same_as(loop_var)) {
              return (mul->a * expand) * mul->b;
          }
      }
      return base; 
  }
  Stmt MutateStore(const BufferStoreNode* store, const Var& loop_var) {
      auto n = tvm::ffi::make_object<BufferStoreNode>(*store);
      Array<PrimExpr> new_indices = store->indices;

      if (const auto* ramp = store->indices[0].as<RampNode>()) {
          PrimExpr new_base = UpdateIndexBase(ramp->base, loop_var, expand_);
          
          int new_lanes = ramp->lanes.as<IntImmNode>()->value * expand_;
          
          new_indices.Set(0, Ramp(new_base, ramp->stride, new_lanes));
wangziyang's avatar
wangziyang committed
121
122
      }

qisan's avatar
qisan committed
123
124
125
126
127
128
129
130
      PrimExpr new_value = store->value;
      if (const auto* load = store->value.as<BufferLoadNode>()) {
          if (const auto* l_ramp = load->indices[0].as<RampNode>()) {
              Array<PrimExpr> v_indices = load->indices;
              int v_new_lanes = l_ramp->lanes.as<IntImmNode>()->value * expand_;
              v_indices.Set(0, Ramp(l_ramp->base, l_ramp->stride, v_new_lanes));
              new_value = BufferLoad(load->buffer, v_indices);
          }
131
      }
wangziyang's avatar
wangziyang committed
132

qisan's avatar
qisan committed
133
      return BufferStore(store->buffer, new_value, new_indices);
wangziyang's avatar
wangziyang committed
134
  }
135
};
wangziyang's avatar
wangziyang committed
136
137


138
139
140
141
/*!
 * \brief Runs BLocalLayoutTransformer over a statement.
 * \param stmt Statement to rewrite (moved into the mutator).
 * \param expand Number of consecutive loop iterations merged per widened store.
 * \return The rewritten statement.
 */
Stmt InjectBLocalLayoutTransformPass(Stmt stmt, int expand) {
  BLocalLayoutTransformer mutator(expand);
  return mutator(std::move(stmt));
}

wangziyang's avatar
wangziyang committed
142
143
144
145
146
147
148
149
150
151
152
153

using namespace tir::transform;

/*!
 * \brief Creates the PrimFunc pass that widens B_local stores on DCU targets.
 *
 * Functions in modules for which IsDCUTarget(m) is false are returned
 * unchanged; otherwise the body is rewritten by
 * InjectBLocalLayoutTransformPass with a fixed expand factor.
 */
tvm::transform::Pass InjectBLocalLayoutTransform() {
  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Number of consecutive loop iterations merged into one widened store.
    constexpr int kExpandFactor = 2;

    // Only apply to DCU targets.
    if (!IsDCUTarget(m)) {
      // Debug-build trace only: printing unconditionally to stdout would emit
      // this line for every PrimFunc compiled for a non-DCU target.
      DLOG(INFO) << "[DEBUG InjectBLocalLayoutTransform] Not a DCU target, skipping";
      return f;
    }

    auto* n = f.CopyOnWrite();
    n->body = InjectBLocalLayoutTransformPass(n->body, kExpandFactor);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  // Register the pass factory under its global "tl.transform" name.
  tvm::ffi::reflection::GlobalDef().def(
      "tl.transform.InjectBLocalLayoutTransform", InjectBLocalLayoutTransform);
}

}  // namespace tl
}  // namespace tvm