thread_storage_scope.h 5.65 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2017 by Contributors
 * \file thread_storage_scope.h
 * \brief Storage/thread scope descriptions and extraction of the thread
 *        axis configuration from DGLArgs.
 */
6
7
#ifndef DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_
Minjie Wang's avatar
Minjie Wang committed
8
9
10
11
12

#include <dgl/runtime/packed_func.h>
#include <algorithm>
#include <string>
#include <vector>

13
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
namespace runtime {

/*!
 * \brief Memory hierarchy rank in the storage system.
 * \note The global rank and shared rank have a one-to-one
 *       correspondence to the thread rank.
 */
enum class StorageRank : int {
  kGlobal = 0,  /*!< global memory */
  kShared = 1,  /*!< shared memory among a thread group */
  /*!
   * Reserved for warp memory. This is only used by the programming
   * model; GPUs usually have no such memory, but it can be simulated
   * with registers and shuffle.
   */
  kWarp = 2,
  kLocal = 3    /*!< thread-local memory */
};

/*!
 * \brief Pick the storage rank used by default at a given thread scope.
 *
 * Rank 0 is what ThreadScope::make assigns to "blockIdx.*" axes and
 * rank 1 to "threadIdx.*" / virtual threads; -1 presumably denotes
 * "not bound to any thread axis" (TODO confirm against callers).
 *
 * \param thread_scope_rank The thread scope rank (-1, 0 or 1).
 * \return kGlobal for -1, kShared for 0, kLocal for 1. Any other value
 *         is reported through LOG(FATAL).
 */
inline StorageRank DefaultStorageRank(int thread_scope_rank) {
  if (thread_scope_rank == -1) {
    return StorageRank::kGlobal;
  }
  if (thread_scope_rank == 0) {
    return StorageRank::kShared;
  }
  if (thread_scope_rank == 1) {
    return StorageRank::kLocal;
  }
  LOG(FATAL) << "unknown rank";
  // Unreachable in practice; keeps the compiler happy about the return path.
  return StorageRank::kGlobal;
}

/*! \brief class to represent storage scope */
struct StorageScope {
  /*! \brief The rank of the storage */
  StorageRank rank{StorageRank::kGlobal};
  /*! \brief tag for special purpose memory: the suffix following the rank name. */
  std::string tag;
  /*! \brief Two scopes are equal when both rank and tag match. */
  inline bool operator==(const StorageScope& other) const {
    return rank == other.rank && tag == other.tag;
  }
  inline bool operator!=(const StorageScope& other) const {
    return !(*this == other);
  }
  /*!
   * \brief Render the scope as a string.
   * \return The rank name ("global", "shared", "warp" or "local")
   *         followed by the tag; this is the inverse of make().
   */
  inline std::string to_string() const {
    switch (rank) {
      case StorageRank::kGlobal: return "global" + tag;
      case StorageRank::kShared: return "shared" + tag;
      case StorageRank::kWarp: return "warp" + tag;
      case StorageRank::kLocal: return "local" + tag;
      default: LOG(FATAL) << "unknown storage scope"; return "";
    }
  }
  /*!
   * \brief make storage scope from string
   * \param s The string to be parsed: a rank name prefix ("global",
   *        "shared", "warp" or "local") optionally followed by a tag.
   * \return The storage scope. An unknown prefix is reported via LOG(FATAL).
   */
  static StorageScope make(const std::string& s) {
    StorageScope r;
    if (s.compare(0, 6, "global")  == 0) {
      r.rank = StorageRank::kGlobal;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 6, "shared") == 0) {
      r.rank = StorageRank::kShared;
      r.tag = s.substr(6, std::string::npos);
    } else if (s.compare(0, 4, "warp") == 0) {
      r.rank = StorageRank::kWarp;
      r.tag = s.substr(4, std::string::npos);
    } else if (s.compare(0, 5, "local") == 0) {
      r.rank = StorageRank::kLocal;
      r.tag = s.substr(5, std::string::npos);
    } else {
      LOG(FATAL) << "unknown storage scope " << s;
    }
    return r;
  }
};

/*! \brief class to represent thread scope */
struct ThreadScope {
  /*! \brief The rank of thread scope: 0 for "blockIdx.*", 1 for "threadIdx.*" and virtual threads. */
  int rank{0};
  /*! \brief the dimension index under the rank: 'x' -> 0, 'y' -> 1, 'z' -> 2; -1 for virtual threads. */
  int dim_index{0};
  /*!
   * \brief make thread scope from string
   * \param s The string to be parsed, e.g. "blockIdx.x", "threadIdx.z",
   *        "vthread" or "cthread".
   * \return The thread scope. Unknown tags, or a missing/invalid axis
   *         letter after "blockIdx." / "threadIdx.", are reported via LOG(FATAL).
   */
  static ThreadScope make(const std::string& s) {
    // Decode the axis letter at `pos`. Without this check, "blockIdx." would
    // read the terminating '\0', and letters outside x..z would give a
    // dim_index that later indexes past the 6-entry workload array.
    auto parse_axis = [&s](size_t pos) -> int {
      if (pos >= s.size() || s[pos] < 'x' || s[pos] > 'z') {
        LOG(FATAL) << "Unknown thread axis in threadscope " << s;
        return 0;
      }
      return static_cast<int>(s[pos] - 'x');
    };
    ThreadScope r;
    if (s == "vthread" || s == "cthread") {
      // virtual thread at the same level as local
      r.rank = 1;
      r.dim_index = -1;
    } else if (s.compare(0, 9, "blockIdx.") == 0) {
      r.rank = 0;
      r.dim_index = parse_axis(9);
    } else if (s.compare(0, 10, "threadIdx.") == 0) {
      r.rank = 1;
      r.dim_index = parse_axis(10);
    } else {
      LOG(FATAL) << "Unknown threadscope " << s;
    }
    return r;
  }
};


/*!
 * \brief Workload specification of a kernel launch.
 *
 * work_size holds six extents: slots [0, 3) are the grid dimensions
 * returned by grid_dim(), slots [3, 6) are the block dimensions
 * returned by block_dim().
 */
struct ThreadWorkLoad {
  /*! \brief Launch extents: grid dims in [0, 3), block dims in [3, 6). */
  size_t work_size[6];
  /*!
   * \param i The grid dimension.
   * \return i-th grid dim
   */
  inline size_t grid_dim(size_t i) const { return work_size[i]; }
  /*!
   * \param i The block dimension.
   * \return i-th block dim
   */
  inline size_t block_dim(size_t i) const {
    const size_t kBlockOffset = 3;  // block dims follow the three grid dims
    return work_size[kBlockOffset + i];
  }
};
/*! \brief Thread axis configuration */
class ThreadAxisConfig {
 public:
  void Init(size_t base,
            const std::vector<std::string>& thread_axis_tags)  {
    base_ = base;
    std::vector<bool> filled(6, false);
    for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
      const std::string& tag = thread_axis_tags[i];
      ThreadScope ts = ThreadScope::make(tag);
      arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
      filled[ts.rank * 3 + ts.dim_index] = true;
    }
    work_dim_ = 1;
    for (int i = 0; i < 3; ++i) {
      if (filled[i] || filled[i + 3]) {
        work_dim_ = i + 1;
      }
    }
  }
  // extract workload from arguments.
173
  ThreadWorkLoad Extract(DGLArgs x) const {
Minjie Wang's avatar
Minjie Wang committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
    ThreadWorkLoad w;
    std::fill(w.work_size, w.work_size + 6, 1);
    for (size_t i = 0; i < arg_index_map_.size(); ++i) {
      w.work_size[arg_index_map_[i]] =
          static_cast<size_t>(x.values[base_ + i].v_int64);
    }
    return w;
  }
  // return the work dim
  size_t work_dim() const {
    return work_dim_;
  }

 private:
  /*! \brief base axis */
  size_t base_;
  /*! \brief The worker dimension */
  size_t work_dim_;
  /*! \brief The index mapping. */
  std::vector<uint32_t> arg_index_map_;
};

}  // namespace runtime
197
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
198
199
200

namespace std {
template <>
201
202
struct hash<::dgl::runtime::StorageScope> {
  std::size_t operator()(const ::dgl::runtime::StorageScope& k) const {
Minjie Wang's avatar
Minjie Wang committed
203
204
205
206
    return static_cast<size_t>(k.rank);
  }
};
}  // namespace std
207
#endif  // DGL_RUNTIME_THREAD_STORAGE_SCOPE_H_