c_runtime_api.cc 12.5 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <algorithm>
#include <array>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include "runtime_base.h"

20
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
namespace runtime {

/*!
 * \brief Map a device type code to the suffix used for its Device API
 *        factory name in the registry.
 * \param type The device type code.
 * \return The lowercase device name; unknown codes trigger LOG(FATAL) and
 *         fall back to "Unknown".
 */
inline std::string DeviceName(int type) {
  const char* name = nullptr;
  switch (type) {
    case kDLCPU:     name = "cpu";     break;
    case kDLGPU:     name = "gpu";     break;
    case kDLOpenCL:  name = "opencl";  break;
    case kDLSDAccel: name = "sdaccel"; break;
    case kDLAOCL:    name = "aocl";    break;
    case kDLVulkan:  name = "vulkan";  break;
    case kDLMetal:   name = "metal";   break;
    case kDLVPI:     name = "vpi";     break;
    case kDLROCM:    name = "rocm";    break;
    case kOpenGL:    name = "opengl";  break;
    case kExtDev:    name = "ext_dev"; break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      name = "Unknown";
      break;
  }
  return name;
}

/*!
 * \brief Lazily-populated cache of DeviceAPI instances, keyed by device type.
 *
 * The first request for a device type looks up the registry factory
 * "device_api.<name>" (where <name> comes from DeviceName, or "rpc" for
 * device types at or above kRPCSessMask) and caches the returned pointer.
 */
class DeviceAPIManager {
 public:
  /*! \brief Number of non-RPC device types that can be cached. */
  static const int kMaxDeviceAPI = 32;

  /*!
   * \brief Get the DeviceAPI for the device type of a context.
   * \param ctx The context whose device_type selects the API.
   * \return The API; a missing API fails the CHECK in GetAPI.
   */
  static DeviceAPI* Get(const DGLContext& ctx) {
    return Get(ctx.device_type);
  }

  /*!
   * \brief Get the DeviceAPI for a raw device type code.
   * \param dev_type The device type code.
   * \param allow_missing If true, return nullptr when no factory is
   *        registered instead of failing the CHECK.
   * \return The API, or nullptr if it is missing and allow_missing is true.
   */
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  // One cached API per non-RPC device type; nullptr until first requested.
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  // Single cached API shared by all device types >= kRPCSessMask.
  DeviceAPI* rpc_api_{nullptr};
  // Guards initialization of api_ and rpc_api_.
  std::mutex mutex_;

  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }

  // The single process-wide manager (function-local static).
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }

  // Get or initialize the API for a device type code.
  // NOTE(review): the first, unlocked read of api_[type] / rpc_api_ is not
  // synchronized with the write made under mutex_; consider std::atomic if
  // concurrent first use from several threads is expected.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      // api_ has only kMaxDeviceAPI slots, yet every code below kRPCSessMask
      // (including negative ones) reaches this branch; reject codes that
      // would index outside the array instead of corrupting memory.
      CHECK(type >= 0 && type < kMaxDeviceAPI)
          << "Device type " << type << " is out of the supported range.";
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    }
    if (rpc_api_ != nullptr) return rpc_api_;
    std::lock_guard<std::mutex> lock(mutex_);
    if (rpc_api_ != nullptr) return rpc_api_;
    rpc_api_ = GetAPI("rpc", allow_missing);
    return rpc_api_;
  }

  // Instantiate the API through the registered "device_api.<name>" factory.
  // Returns nullptr when the factory is absent and allow_missing is true.
  DeviceAPI* GetAPI(const std::string& name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

97
/*!
 * \brief Resolve the DeviceAPI that serves the device type of \p ctx.
 * \param ctx Context whose device_type selects the API.
 * \param allow_missing Forwarded to DeviceAPIManager::Get.
 */
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}

102
103
104
105
/*!
 * \brief Resolve the DeviceAPI for an explicit device type.
 * \param dev_type The device type.
 * \param allow_missing Forwarded to DeviceAPIManager::Get.
 */
DeviceAPI* DeviceAPI::Get(DLDeviceType dev_type, bool allow_missing) {
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

106
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
107
                                size_t size,
108
                                DGLType type_hint) {
Minjie Wang's avatar
Minjie Wang committed
109
110
111
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

112
/*! \brief Default workspace release: forwards to FreeDataSpace. */
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

116
/*!
 * \brief Default stream creation for devices without stream support.
 *
 * Reports the missing capability via LOG(FATAL); the trailing return only
 * satisfies the signature, since LOG(FATAL) is not expected to return.
 */
DGLStreamHandle DeviceAPI::CreateStream(DGLContext /*ctx*/) {
  LOG(FATAL) << "Device does not support stream api.";
  const DGLStreamHandle no_stream = 0;
  return no_stream;
}

121
/*! \brief Default stream release: reports missing stream support via LOG(FATAL). */
void DeviceAPI::FreeStream(DGLContext /*ctx*/, DGLStreamHandle /*stream*/) {
  LOG(FATAL) << "Device does not support stream api.";
}

125
126
127
/*!
 * \brief Default cross-stream synchronization: reports missing stream
 *        support via LOG(FATAL).
 */
void DeviceAPI::SyncStreamFromTo(DGLContext /*ctx*/,
                                 DGLStreamHandle /*event_src*/,
                                 DGLStreamHandle /*event_dst*/) {
  LOG(FATAL) << "Device does not support stream api.";
}
130

131
void DeviceAPI::PinData(void* ptr, size_t nbytes) {
132
133
134
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
}

135
void DeviceAPI::UnpinData(void* ptr) {
136
137
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
138
}  // namespace runtime
139
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
140

141
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
142

143
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
144
145
  std::string ret_str;
  std::string last_error;
146
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
147
148
};

149
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;  // thread-local DGLRuntimeEntry accessor
Minjie Wang's avatar
Minjie Wang committed
150

151
152
const char *DGLGetLastError() {
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
153
154
}

155
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
156
#ifndef _LIBCPP_SGX_CONFIG
157
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
158
159
160
161
162
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

163
int DGLModLoadFromFile(const char* file_name,
Minjie Wang's avatar
Minjie Wang committed
164
                       const char* format,
165
                       DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
166
167
168
169
170
171
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

172
173
/*! \brief Add module \p dep to the imports of module \p mod. */
int DGLModImport(DGLModuleHandle mod,
                 DGLModuleHandle dep) {
  API_BEGIN();
  Module* target = static_cast<Module*>(mod);
  const Module& dependency = *static_cast<Module*>(dep);
  target->Import(dependency);
  API_END();
}

180
int DGLModGetFunction(DGLModuleHandle mod,
Minjie Wang's avatar
Minjie Wang committed
181
182
                      const char* func_name,
                      int query_imports,
183
                      DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
184
185
186
187
188
189
190
191
192
193
194
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
      func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

195
/*! \brief Destroy a module handle previously returned by this API. */
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

201
int DGLBackendGetFuncFromEnv(void* mod_node,
Minjie Wang's avatar
Minjie Wang committed
202
                             const char* func_name,
203
                             DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
204
  API_BEGIN();
205
  *func = (DGLFunctionHandle)(
Minjie Wang's avatar
Minjie Wang committed
206
207
208
209
      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

210
void* DGLBackendAllocWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
211
212
213
214
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
215
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
216
217
218
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;

219
  DGLType type_hint;
Minjie Wang's avatar
Minjie Wang committed
220
221
222
223
224
225
226
227
228
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

229
int DGLBackendFreeWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
230
231
                            int device_id,
                            void* ptr) {
232
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
233
234
235
236
237
238
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

239
/*!
 * \brief Call \p f(cdata) only if \p *handle is still null.
 *
 * The handle is marked non-null before \p f runs, so later calls with the
 * same handle return 0 without invoking \p f. No locking is performed here.
 * \param handle Flag slot; null means "not run yet".
 * \param f Callback to invoke once.
 * \param cdata Argument passed to \p f.
 * \param nbytes Unused.
 * \return The result of \p f on the first call, 0 afterwards.
 */
int DGLBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return (*f)(cdata);
}

250
/*! \brief Destroy a PackedFunc handle previously returned by this API. */
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

256
257
int DGLFuncCall(DGLFunctionHandle func,
                DGLValue* args,
Minjie Wang's avatar
Minjie Wang committed
258
259
                int* arg_type_codes,
                int num_args,
260
                DGLValue* ret_val,
Minjie Wang's avatar
Minjie Wang committed
261
262
                int* ret_type_code) {
  API_BEGIN();
263
  DGLRetValue rv;
Minjie Wang's avatar
Minjie Wang committed
264
  (*static_cast<const PackedFunc*>(func)).CallPacked(
265
      DGLArgs(args, arg_type_codes, num_args), &rv);
Minjie Wang's avatar
Minjie Wang committed
266
267
  // handle return string.
  if (rv.type_code() == kStr ||
268
     rv.type_code() == kDGLType ||
Minjie Wang's avatar
Minjie Wang committed
269
      rv.type_code() == kBytes) {
270
271
    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
    if (rv.type_code() != kDGLType) {
Minjie Wang's avatar
Minjie Wang committed
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

291
292
/*!
 * \brief Store a single C value into a DGLRetValue from a C callback.
 * \param ret Handle to the DGLRetValue to write.
 * \param value Array holding the value to store.
 * \param type_code Array holding the type code of \p value[0].
 * \param num_ret Number of values; only 1 is accepted (CHECK_EQ).
 */
int DGLCFuncSetReturn(DGLRetValueHandle ret,
                      DGLValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue* target = static_cast<DGLRetValue*>(ret);
  *target = DGLArgValue(value[0], type_code[0]);
  API_END();
}

302
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
Minjie Wang's avatar
Minjie Wang committed
303
                           void* resource_handle,
304
305
                           DGLPackedCFuncFinalizer fin,
                           DGLFunctionHandle *out) {
Minjie Wang's avatar
Minjie Wang committed
306
307
308
  API_BEGIN();
  if (fin == nullptr) {
    *out = new PackedFunc(
309
310
        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
311
312
                         args.num_args, rv, resource_handle);
          if (ret != 0) {
313
314
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
315
316
317
318
319
320
321
322
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
323
324
        [func, rpack](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
325
326
                         args.num_args, rv, rpack.get());
          if (ret != 0) {
327
328
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
329
330
331
332
333
334
335
            throw dmlc::Error(err);
          }
      });
  }
  API_END();
}

336
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
337
  API_BEGIN();
338
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
339
340
341
342
343
344
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

345
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
346
  API_BEGIN();
347
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
348
349
350
351
352
353
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

354
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
355
  API_BEGIN();
356
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
357
358
359
360
361
362
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

363
364
365
366
367
368
369
370
371
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

372
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
373
  API_BEGIN();
374
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
375
376
377
378
379
380
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

381
int DGLStreamStreamSynchronize(int device_type,
Minjie Wang's avatar
Minjie Wang committed
382
                               int device_id,
383
384
                               DGLStreamHandle src,
                               DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
385
  API_BEGIN();
386
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
387
388
389
390
391
392
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

393
/*!
 * \brief Round-trip a callback argument through DGLRetValue in place.
 *
 * \p *value is wrapped as a DGLArgValue, assigned into a DGLRetValue and
 * moved back out into \p value; the resulting type code must equal \p code.
 */
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  DGLRetValue holder;
  holder = DGLArgValue(*value, code);
  int moved_code;
  holder.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

403
404
/*!
 * \brief Load the tensor adapter library at \p path into the global
 *        TensorDispatcher.
 * \return 0 if Load reports success, -1 otherwise.
 */
int DGLLoadTensorAdapter(const char *path) {
  const bool loaded = TensorDispatcher::Global()->Load(path);
  return loaded ? 0 : -1;
}

Minjie Wang's avatar
Minjie Wang committed
407
// set device api
408
409
410
// Registered packed function: args are (device_type, device_id); calls the
// device API's SetDevice for that context.
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue* /*ret*/) {
    const int type_code = args[0];
    const int device_id = args[1];
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(type_code);
    ctx.device_id = device_id;
    DeviceAPIManager::Get(ctx)->SetDevice(ctx);
  });

// get device attribute api
417
418
419
// Registered packed function: args are (device_type, device_id, attr_kind).
// For kExist the device API is looked up with allow_missing=true and a
// missing API yields 0; other kinds require the API to be present.
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    const auto kind = static_cast<DeviceAttrKind>(args[2].operator int());

    if (kind != kExist) {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
    } else {
      api->GetAttr(ctx, kind, ret);
    }
  });
435