storage_access.h 5.78 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_access.h
 * \brief Common data structure for storage access analysis.
 */
#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_
#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_

#include <tvm/arith/int_set.h>
#include <tvm/ir/attrs.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <vector>

#include "arith/ir_visitor_with_analyzer.h"
#include "runtime/thread_storage_scope.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;
using runtime::StorageRank;
using runtime::StorageScope;

/*!
 * \brief Base class of storage access analysis.
 *
 * An IRVisitorWithAnalyzer that overrides the handlers for buffer
 * loads/stores, evaluate, let, attribute, loop, conditional, call and
 * block nodes. The class cannot be instantiated directly: its
 * constructor is protected and Summarize() is pure virtual, so a
 * concrete analysis must derive from it and implement Summarize().
 */
class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer {
public:
  /*! \brief Storage access type */
  enum AccessType : uint8_t {
    kRead,
    kWrite,
    kSync,
    kAlloc,
    // Acquired version of read; only the WAR dependency needs handling.
    kReadAcquire
  };
  /*! \brief An access entry */
  struct AccessEntry {
    /*! \brief The thread indices that access this entry */
    Array<IterVar> threads;
    /*! \brief The touched thread range */
    Map<Var, Range> thread_range;
    /*! \brief The indices used for the buffer access */
    Array<PrimExpr> buffer_indices;
    /*! \brief The buffer ranges for pointer access */
    Array<Range> buffer_ranges;
    /*! \brief The buffer variable, if any (defaults to a null Var) */
    Var buffer = NullValue<Var>();
    /*! \brief The access data type */
    DataType dtype;
    /*! \brief The touched access range
     *
     * Has one IntSet for each index in the buffer being accessed.
     */
    Array<arith::IntSet> touched;
    /*! \brief The type of access (no default; must be set by the creator) */
    AccessType type;
    /*! \brief The storage scope */
    StorageScope scope;
    /*! \brief Whether the access is double buffer write */
    bool double_buffer_write = false;
    /*! \brief Whether the access is pointer access */
    bool is_pointer_access = false;
  };

  /*! \brief Access pattern about a single statement */
  struct StmtEntry {
    /*! \brief The statement (null when default-constructed) */
    const Object *stmt{};
    /*! \brief Access patterns in the statement */
    std::vector<AccessEntry> access;
  };
  // Visitor overrides. All handlers are final except the AttrStmtNode
  // one, which is declared `override` and can therefore still be
  // overridden by subclasses.
  void VisitExpr_(const BufferLoadNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitStmt_(const EvaluateNode *op) final;
  void VisitStmt_(const LetStmtNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) override;
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const IfThenElseNode *op) final;
  void VisitStmt_(const WhileNode *op) final;
  void VisitExpr_(const CallNode *op) final;
  void VisitStmt_(const BlockNode *op) final;

  /*!
   * \brief Register the buffer associated with a buffer data variable.
   *
   * Stores (or overwrites) the entry for \p buffer_var in the internal
   * buffer map.
   *
   * \param buffer_var The buffer data variable used as the key.
   * \param buffer The buffer to associate with \p buffer_var.
   */
  void SetBufferDataToBuffer(const Var &buffer_var, const Buffer &buffer) {
    buffer_data_to_buffer_.Set(buffer_var, buffer);
  }

protected:
  /*!
   * \brief Construct the visitor and seed scope_ with one empty
   *  statement-entry list.
   */
  TileLangStorageAccessVisitor() { scope_.push_back(std::vector<StmtEntry>()); }
  /*! \return number of conditions in the current scope. */
  int condition_counter() const { return condition_counter_; }
  /*! \return whether we are in device environment. */
  bool in_device_env() const { return in_device_env_; }
  /*! \return environment threads */
  const Array<IterVar> &env_threads() const { return env_threads_; }
  /*!
   * \brief Whether we need to analyze the buffer in the current scope.
   *
   * The default implementation ignores both arguments and enables the
   * analysis for every buffer; subclasses may override it to filter.
   *
   * \param buffer The buffer to be checked
   * \param scope The scope of the buffer.
   * \return Whether the analysis of buffer is enabled.
   */
  virtual bool Enabled(const VarNode *buffer, const StorageScope &scope) const {
    return true;
  }
  /*!
   * \brief Summarize the sequence of operations into parent.
   *
   *  Insert synchronization if necessary and remove unnecessary
   *  memory accesses which are already synced.
   *
   * \param seq The sequence of the access operations (taken by value).
   * \param loop Pass loop node if it is a loop, otherwise nullptr.
   * \return The summarized sequence that represents the accesses that
   *  the parent should take care of to synchronize.
   */
  virtual std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
                                             const ForNode *loop) = 0;

  /*!
   * \brief Compute the thread range for the given threads.
   * \param threads The threads to compute the range for.
   * \return The thread range.
   */
  Map<Var, Range> ComputeThreadRange(const Array<IterVar> &threads);

  /*!
   * \brief Get the scope of the buffer array.
   * \param buffer_var The buffer variable whose scope is queried.
   * \return The scope of the final buffer array.
   */
  StorageScope GetScope(const Var &buffer_var) const;
  // Access scopes: each element is the list of statement entries of one
  // scope. The constructor pushes the first (empty) list.
  std::vector<std::vector<StmtEntry>> scope_;

private:
  // Whether access appending is enabled.
  bool allow_append_{false};
  // Whether we are in device environment.
  bool in_device_env_{false};
  // Number of conditions in the current scope (see condition_counter()).
  int condition_counter_{0};
  // The current double buffer write scope.
  const VarNode *double_buffer_write_{nullptr};
  // The current free stmt entry.
  StmtEntry curr_stmt_;
  // The involved threads (exposed through env_threads()).
  Array<IterVar> env_threads_;
  // Map from buffer data variable to its buffer, filled by
  // SetBufferDataToBuffer().
  Map<Var, Buffer> buffer_data_to_buffer_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_