inject_ds_read.cc 5.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Replace shared memory BufferLoad with ds_read hardware instructions
 * \file inject_ds_read.cc
 */
wangziyang's avatar
wangziyang committed
24
#include <iostream>
25
26
27
28
29
30
31
32
33
34
35
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
wangziyang's avatar
wangziyang committed
36
#include "inject_utils.h"
37
38
39
40
41
42

namespace tvm {
namespace tl {

using namespace tir;

wangziyang's avatar
wangziyang committed
43
class DSReadInjector : public StmtExprMutator {
44
 public:
wangziyang's avatar
wangziyang committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
  /*!
   * \brief Visit EvaluateNode to handle explicit ds_read_vector call
   * ds_read_vector Call is wrapped in Evaluate to become a statement
   * Parameters m, n, offset are passed explicitly via CallNode args
   */
  Stmt VisitStmt_(const EvaluateNode* op) override {
    std::cout << "[DEBUG VisitStmt_] Visiting EvaluateNode" << std::endl;
    const CallNode* call = op->value.as<CallNode>();
    std::cout << "[DEBUG VisitStmt_] CallNode ptr: " << call << std::endl;
    if (call != nullptr && call->op.same_as(ds_read_vector())) {
      ICHECK(call->args.size() == 5)
          << "ds_read_vector expects 5 arguments: (dst, src, m, n, offset)";

      // Print args for debugging - these are the actual CallNode args passed in
      std::cout << "[DEBUG ds_read_vector] args[0] (dst): " << call->args[0] << std::endl;
      std::cout << "[DEBUG ds_read_vector] args[1] (src): " << call->args[1] << std::endl;
      std::cout << "[DEBUG ds_read_vector] args[2] (m): " << call->args[2] << std::endl;
      std::cout << "[DEBUG ds_read_vector] args[3] (n): " << call->args[3] << std::endl;
      std::cout << "[DEBUG ds_read_vector] args[4] (offset): " << call->args[4] << std::endl;
    }
    // Continue with default traversal (don't replace the existing call)
    return StmtExprMutator::VisitStmt_(op);
  }

  /*!
   * \brief Visit BufferStoreNode to inject ds_read_vector call
   * Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
wangziyang's avatar
wangziyang committed
72
   * Parameters m, n, offset are passed via a CallNode (tl.ds_read_config)
wangziyang's avatar
wangziyang committed
73
74
75
76
   */
  Stmt VisitStmt_(const BufferStoreNode* op) override {
    std::cout << "[DEBUG VisitStmt_] Visiting BufferStoreNode" << std::endl;

77
    // Check if the store is to a local register (not shared memory)
wangziyang's avatar
wangziyang committed
78
79
80
81
    bool is_local = op->buffer.scope() == "local" ||
                    op->buffer.scope() == "local.fragment";
    std::cout << "[DEBUG BufferStore] is_local: " << is_local
              << ", scope: " << op->buffer.scope() << std::endl;
82
83

    if (!is_local) {
wangziyang's avatar
wangziyang committed
84
      return StmtExprMutator::VisitStmt_(op);
85
86
87
    }

    // Check if the value is a BufferLoad from shared memory
wangziyang's avatar
wangziyang committed
88
89
90
91
    const BufferLoadNode* load = op->value.as<BufferLoadNode>();
    if (load == nullptr) {
      std::cout << "[DEBUG BufferStore] value is not BufferLoad" << std::endl;
      return StmtExprMutator::VisitStmt_(op);
92
93
    }

wangziyang's avatar
wangziyang committed
94
95
96
97
    bool is_shared_load = load->buffer.scope() == "shared" ||
                         load->buffer.scope() == "shared.dyn";
    std::cout << "[DEBUG BufferStore] is_shared_load: " << is_shared_load
              << ", load scope: " << load->buffer.scope() << std::endl;
98

wangziyang's avatar
wangziyang committed
99
100
    if (!is_shared_load) {
      return StmtExprMutator::VisitStmt_(op);
101
102
    }

wangziyang's avatar
wangziyang committed
103
104
    // For A_shared, use the actual shared memory base pointer
    PrimExpr m = make_const(DataType::Int(32), 32);
wangziyang's avatar
wangziyang committed
105
106
107
    PrimExpr n = make_const(DataType::Int(32), 16);
    PrimExpr offset = make_const(DataType::Int(32), 0);

wangziyang's avatar
wangziyang committed
108
    // Use buffer data vars directly
wangziyang's avatar
wangziyang committed
109
    Array<PrimExpr> new_args = {
wangziyang's avatar
wangziyang committed
110
111
112
113
114
        load->buffer->data,   // src
        op->buffer->data,      // dst
        m,
        n,
        offset
115
116
    };

wangziyang's avatar
wangziyang committed
117
118
119
    // Create the ds_read call
    Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), new_args);
    return Evaluate(ds_read_call);
120
121
122
123
124
125
126
  }
};

using namespace tir::transform;

/*!
 * \brief Create the PrimFunc pass that applies DSReadInjector.
 *
 * The pass is a no-op (returns the function unchanged) when IsDCUTarget
 * reports false for the module; otherwise the function body is rewritten by
 * DSReadInjector.
 *
 * \return A PrimFunc pass named "tl.InjectDSRead" created with opt level 0 and
 *         no required passes.
 */
tvm::transform::Pass InjectDSRead() {
  auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets
    if (!IsDCUTarget(m)) {
      return f;
    }

    // Obtain a writable node before replacing the body.
    auto *n = f.CopyOnWrite();
    n->body = DSReadInjector()(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectDSRead", {});
}

// Register the InjectDSRead pass factory as the global definition
// "tl.transform.InjectDSRead" during static initialization.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.InjectDSRead",
                                        InjectDSRead);
}

}  // namespace tl
wangziyang's avatar
wangziyang committed
148
}  // namespace tvm