pack_args.h 9.55 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2017 by Contributors
 * \file pack_args.h
4
5
 * \brief Utility to pack DGLArgs into other type-erased function calling
 * conventions.
Minjie Wang's avatar
Minjie Wang committed
6
7
8
9
10
11
12
13
 *
 *  Two type erased function signatures are supported.
 *   - cuda_style(void** args, int num_args);
 *      - Pack everything by address
 *   - metal_style(void** buffers, int num_buffers,
 *                 union_32bit args[N], int num_args);
 *      - Pack buffer by address, pack rest parameter into 32bit union buffer.
 */
14
15
#ifndef DGL_RUNTIME_PACK_ARGS_H_
#define DGL_RUNTIME_PACK_ARGS_H_
Minjie Wang's avatar
Minjie Wang committed
16
17

#include <dgl/runtime/c_runtime_api.h>
18
#include <dgl/runtime/packed_func.h>
19

Minjie Wang's avatar
Minjie Wang committed
20
#include <cstring>
21
#include <vector>
Minjie Wang's avatar
Minjie Wang committed
22

23
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
24
25
26
27
28
29
30
31
32
33
34
35
36
namespace runtime {
/*!
 * \brief 32-bit argument union used to hold narrowed scalar arguments.
 *
 * 32 bits is chosen because most GPU APIs do not work well with 64-bit
 * scalar arguments. Member order matters: brace-initialization of a union
 * initializes its first member (v_int32).
 */
union ArgUnion {
  int32_t v_int32;    // signed 32-bit integer view
  uint32_t v_uint32;  // unsigned 32-bit integer view
  float v_float32;    // 32-bit floating-point view
};
/*!
 * \brief Create a packed function from void addr types.
 *
37
 * \param f with signiture (DGLArgs args, DGLRetValue* rv, void* void_args)
Minjie Wang's avatar
Minjie Wang committed
38
39
40
41
42
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
43
44
45
template <typename F>
inline PackedFunc PackFuncVoidAddr(
    F f, const std::vector<DGLDataType>& arg_types);
Minjie Wang's avatar
Minjie Wang committed
46
/*!
47
48
 * \brief Create a packed function that from function only packs buffer
 * arguments.
Minjie Wang's avatar
Minjie Wang committed
49
 *
50
 * \param f with signiture (DGLArgs args, DGLRetValue* rv, ArgUnion* pack_args)
Minjie Wang's avatar
Minjie Wang committed
51
52
53
54
55
 * \param arg_types The arguments type information.
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
56
57
58
template <typename F>
inline PackedFunc PackFuncNonBufferArg(
    F f, const std::vector<DGLDataType>& arg_types);
Minjie Wang's avatar
Minjie Wang committed
59
/*!
60
61
 * \brief Create a packed function that from function that takes a packed
 * arguments.
Minjie Wang's avatar
Minjie Wang committed
62
 *
63
64
 * \param f with signature (DGLArgs args, DGLRetValue* rv, void* pack_args,
 * size_t nbytes)
Minjie Wang's avatar
Minjie Wang committed
65
66
67
68
69
 * \param arg_types The arguments that wish to get from
 * \tparam F the function type
 *
 * \return The wrapped packed function.
 */
70
71
72
template <typename F>
inline PackedFunc PackFuncPackedArg(
    F f, const std::vector<DGLDataType>& arg_types);
Minjie Wang's avatar
Minjie Wang committed
73
74
75
76
77
/*!
 * \brief Extract number of buffer argument from the argument types.
 * \param arg_types The argument types.
 * \return number of buffer arguments
 */
78
inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types);
Minjie Wang's avatar
Minjie Wang committed
79
80
81

// implementations details
namespace detail {
82
/*!
 * \brief Scratch array of T used while packing arguments.
 *
 * This primary template keeps kSize elements inline, so there is no heap
 * allocation. The requested runtime size is ignored; callers must only pick
 * this form when the runtime size does not exceed kSize.
 */
template <typename T, int kSize>
class TempArray {
 public:
  explicit TempArray(int /*size*/) {}

  T* data() { return storage_; }

 private:
  T storage_[kSize];
};

/*!
 * \brief kSize == 0 specialization: storage is a std::vector sized at
 * construction time (elements are value-initialized).
 */
template <typename T>
class TempArray<T, 0> {
 public:
  explicit TempArray(int size) : storage_(size) {}

  T* data() { return storage_.data(); }

 private:
  std::vector<T> storage_;
};

/*!
 * \brief How a DGLValue argument is converted before reaching a device
 * function. Enumerators are named SOURCE_TO_DESTINATION.
 */
enum ArgConvertCode {
  INT64_TO_INT64 = 0,      // 64-bit integer, passed through unchanged
  INT64_TO_INT32 = 1,      // 64-bit integer narrowed to int32_t
  INT64_TO_UINT32 = 2,     // 64-bit integer narrowed to uint32_t
  FLOAT64_TO_FLOAT32 = 3,  // double narrowed to float
  FLOAT64_TO_FLOAT64 = 4,  // double, passed through unchanged
  HANDLE_TO_HANDLE = 5     // opaque handle, passed through unchanged
};

111
/*!
 * \brief Map a scalar argument type to the conversion applied when packing it.
 *
 * Supported: int64/int32, uint32, float64/float32 and handles. Vector types
 * (lanes != 1) and any other code/bit-width combination abort via CHECK/LOG.
 *
 * \param t The argument's data type.
 * \return The conversion code for \p t.
 */
inline ArgConvertCode GetArgConvertCode(DGLDataType t) {
  CHECK_EQ(t.lanes, 1U)
      << "Cannot pass vector type argument to devic function for now";
  const bool is_64bit = (t.bits == 64U);
  const bool is_32bit = (t.bits == 32U);
  if (t.code == kDGLInt) {
    if (is_64bit || is_32bit) {
      return is_64bit ? INT64_TO_INT64 : INT64_TO_INT32;
    }
  } else if (t.code == kDGLUInt) {
    if (is_32bit) {
      return INT64_TO_UINT32;
    }
  } else if (t.code == kDGLFloat) {
    if (is_64bit || is_32bit) {
      return is_64bit ? FLOAT64_TO_FLOAT64 : FLOAT64_TO_FLOAT32;
    }
  } else if (t.code == kHandle) {
    return HANDLE_TO_HANDLE;
  }
  // Reached only for an unsupported code / bit-width combination.
  LOG(FATAL) << "Cannot handle " << t << " as device function argument";
  return HANDLE_TO_HANDLE;
}

129
130
131
/*!
 * \brief Build the by-address packed function for a fixed conversion plan.
 *
 * Pass-through arguments are addressed in place inside DGLArgs; narrowed
 * arguments are written into a per-call ArgUnion scratch array and addressed
 * there. \p f is then invoked as f(args, rv, addresses).
 *
 * \tparam N Inline capacity of the scratch arrays (0 selects heap storage).
 * \param f The callable to wrap.
 * \param codes One conversion code per argument.
 * \return The wrapped packed function.
 */
template <int N, typename F>
inline PackedFunc PackFuncVoidAddr_(
    F f, const std::vector<ArgConvertCode>& codes) {
  const int nargs = static_cast<int>(codes.size());
  auto wrapper = [f, codes, nargs](DGLArgs args, DGLRetValue* rv) {
    TempArray<void*, N> addr_storage(nargs);
    TempArray<ArgUnion, N> narrow_storage(nargs);
    void** addrs = addr_storage.data();
    ArgUnion* narrowed = narrow_storage.data();
    for (int k = 0; k < nargs; ++k) {
      const DGLValue& src = args.values[k];
      ArgUnion& slot = narrowed[k];
      switch (codes[k]) {
        case INT64_TO_INT32:
          slot.v_int32 = static_cast<int32_t>(src.v_int64);
          addrs[k] = &slot;
          break;
        case INT64_TO_UINT32:
          slot.v_uint32 = static_cast<uint32_t>(src.v_int64);
          addrs[k] = &slot;
          break;
        case FLOAT64_TO_FLOAT32:
          slot.v_float32 = static_cast<float>(src.v_float64);
          addrs[k] = &slot;
          break;
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64:
        case HANDLE_TO_HANDLE:
          // Already in the target representation: point at the value itself.
          addrs[k] = const_cast<void*>(static_cast<const void*>(&src));
          break;
      }
    }
    f(args, rv, addrs);
  };
  return PackedFunc(wrapper);
}

168
template <int N, typename F>
Minjie Wang's avatar
Minjie Wang committed
169
170
171
inline PackedFunc PackFuncNonBufferArg_(
    F f, int base, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
172
  auto ret = [f, codes, base, num_args](DGLArgs args, DGLRetValue* ret) {
Minjie Wang's avatar
Minjie Wang committed
173
174
175
176
177
178
    TempArray<ArgUnion, N> holder_(num_args);
    ArgUnion* holder = holder_.data();
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
179
180
          LOG(FATAL) << "Donot support 64bit argument to device function";
          break;
Minjie Wang's avatar
Minjie Wang committed
181
182
        }
        case INT64_TO_INT32: {
183
184
          holder[i].v_int32 =
              static_cast<int32_t>(args.values[base + i].v_int64);
Minjie Wang's avatar
Minjie Wang committed
185
186
          break;
        }
187
188
189
        case INT64_TO_UINT32: {
          holder[i].v_uint32 =
              static_cast<uint32_t>(args.values[base + i].v_int64);
Minjie Wang's avatar
Minjie Wang committed
190
191
192
          break;
        }
        case FLOAT64_TO_FLOAT32: {
193
194
          holder[i].v_float32 =
              static_cast<float>(args.values[base + i].v_float64);
Minjie Wang's avatar
Minjie Wang committed
195
196
197
          break;
        }
        case HANDLE_TO_HANDLE: {
198
199
          LOG(FATAL) << "not reached";
          break;
Minjie Wang's avatar
Minjie Wang committed
200
201
202
203
204
205
206
207
        }
      }
    }
    f(args, ret, holder);
  };
  return PackedFunc(ret);
}

208
template <int N, typename F>
Minjie Wang's avatar
Minjie Wang committed
209
210
211
inline PackedFunc PackFuncPackedArg_(
    F f, const std::vector<ArgConvertCode>& codes) {
  int num_args = static_cast<int>(codes.size());
212
  auto ret = [f, codes, num_args](DGLArgs args, DGLRetValue* ret) {
Minjie Wang's avatar
Minjie Wang committed
213
214
215
    TempArray<uint64_t, N> pack_(num_args);
    int32_t* pack = reinterpret_cast<int32_t*>(pack_.data());
    int32_t* ptr = pack;
216
    static_assert(sizeof(DGLValue) == 8, "invariant");
Minjie Wang's avatar
Minjie Wang committed
217
218
219
220
221
222
223
224
225
226
    static_assert(sizeof(void*) % sizeof(int32_t) == 0, "invariant");
    for (int i = 0; i < num_args; ++i) {
      switch (codes[i]) {
        case HANDLE_TO_HANDLE: {
          std::memcpy(ptr, &(args.values[i].v_handle), sizeof(void*));
          ptr += sizeof(void*) / sizeof(int32_t);
          break;
        }
        case INT64_TO_INT64:
        case FLOAT64_TO_FLOAT64: {
227
          std::memcpy(ptr, &args.values[i], sizeof(DGLValue));
Minjie Wang's avatar
Minjie Wang committed
228
229
230
231
232
233
234
235
          ptr += 2;
          break;
        }
        case INT64_TO_INT32: {
          *ptr = static_cast<int32_t>(args.values[i].v_int64);
          ++ptr;
          break;
        }
236
        case INT64_TO_UINT32: {
Minjie Wang's avatar
Minjie Wang committed
237
238
239
240
241
242
243
244
245
246
247
248
          *reinterpret_cast<uint32_t*>(ptr) =
              static_cast<uint32_t>(args.values[i].v_int64);
          ++ptr;
          break;
        }
        case FLOAT64_TO_FLOAT32: {
          *reinterpret_cast<float*>(ptr) =
              static_cast<float>(args.values[i].v_float64);
          ++ptr;
          break;
        }
        default: {
249
250
          LOG(FATAL) << "not reached";
          break;
Minjie Wang's avatar
Minjie Wang committed
251
252
253
254
255
256
257
258
259
        }
      }
    }
    f(args, ret, pack, (ptr - pack) * sizeof(int32_t));
  };
  return PackedFunc(ret);
}
}  // namespace detail

260
261
262
/*!
 * \brief Wrap \p f so that each argument is passed by address.
 *
 * Computes one conversion code per argument, then dispatches to the
 * implementation with the smallest inline scratch capacity that fits
 * (4, then 8), falling back to heap storage for longer argument lists.
 *
 * \param f Callable invoked as f(args, rv, void** addresses).
 * \param arg_types Type of each argument, in call order.
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncVoidAddr(
    F f, const std::vector<DGLDataType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(arg_types.size());
  for (const DGLDataType& type : arg_types) {
    codes.push_back(detail::GetArgConvertCode(type));
  }
  const size_t nargs = arg_types.size();
  if (nargs > 8) {
    return detail::PackFuncVoidAddr_<0>(f, codes);
  }
  if (nargs > 4) {
    return detail::PackFuncVoidAddr_<8>(f, codes);
  }
  return detail::PackFuncVoidAddr_<4>(f, codes);
}

278
/*!
 * \brief Count the leading handle (buffer) arguments.
 *
 * Fails a CHECK if any handle argument appears after the first non-handle
 * argument.
 *
 * \param arg_types The argument types, in call order.
 * \return Number of leading handle arguments.
 */
inline size_t NumBufferArgs(const std::vector<DGLDataType>& arg_types) {
  const size_t total = arg_types.size();
  size_t num_buffers = 0;
  while (num_buffers < total && arg_types[num_buffers].code == kHandle) {
    ++num_buffers;
  }
  // Buffers must be grouped at the front of the argument list.
  for (size_t j = num_buffers; j < total; ++j) {
    CHECK(arg_types[j].code != kHandle)
        << "Device function need to be organized";
  }
  return num_buffers;
}

293
294
295
/*!
 * \brief Wrap \p f so that only the non-buffer arguments are packed into an
 * ArgUnion array.
 *
 * Leading handle arguments are counted with NumBufferArgs and skipped. The
 * remaining arguments get conversion codes. Up to four of them use inline
 * scratch storage; more fall back to the heap.
 *
 * \param f Callable invoked as f(args, rv, ArgUnion* pack_args).
 * \param arg_types Type of each argument, in call order.
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncNonBufferArg(
    F f, const std::vector<DGLDataType>& arg_types) {
  const size_t num_buffer = NumBufferArgs(arg_types);
  const size_t total = arg_types.size();
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(total - num_buffer);
  for (size_t k = num_buffer; k < total; ++k) {
    codes.push_back(detail::GetArgConvertCode(arg_types[k]));
  }
  const int base = static_cast<int>(num_buffer);
  return codes.size() <= 4
             ? detail::PackFuncNonBufferArg_<4>(f, base, codes)
             : detail::PackFuncNonBufferArg_<0>(f, base, codes);
}

311
312
313
/*!
 * \brief Wrap \p f so that all arguments are serialized into one contiguous
 * buffer.
 *
 * Computes one conversion code per argument. Up to four arguments use
 * inline scratch storage; more fall back to the heap.
 *
 * \param f Callable invoked as f(args, rv, void* pack_args, size_t nbytes).
 * \param arg_types Type of each argument, in call order.
 * \return The wrapped packed function.
 */
template <typename F>
inline PackedFunc PackFuncPackedArg(
    F f, const std::vector<DGLDataType>& arg_types) {
  std::vector<detail::ArgConvertCode> codes;
  codes.reserve(arg_types.size());
  for (const DGLDataType& type : arg_types) {
    codes.push_back(detail::GetArgConvertCode(type));
  }
  if (codes.size() > 4) {
    return detail::PackFuncPackedArg_<0>(f, codes);
  }
  return detail::PackFuncPackedArg_<4>(f, codes);
}
}  // namespace runtime
327
}  // namespace dgl
328
#endif  // DGL_RUNTIME_PACK_ARGS_H_