c_runtime_api.cc 12.4 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
13
#include <dgl/runtime/tensordispatch.h>
Minjie Wang's avatar
Minjie Wang committed
14
15
16
17
18
19
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include "runtime_base.h"

20
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
namespace runtime {

/*!
 * \brief Map a device type code to the short name used to build the
 *        "device_api.<name>" registry key.
 * \param type The device type code.
 * \return The short device name; for an unrecognised code a fatal error is
 *         logged and "Unknown" is the fallback return value.
 */
inline std::string DeviceName(int type) {
  const char* label = nullptr;
  switch (type) {
    case kDLCPU:     label = "cpu";     break;
    case kDLGPU:     label = "gpu";     break;
    case kDLOpenCL:  label = "opencl";  break;
    case kDLSDAccel: label = "sdaccel"; break;
    case kDLAOCL:    label = "aocl";    break;
    case kDLVulkan:  label = "vulkan";  break;
    case kDLMetal:   label = "metal";   break;
    case kDLVPI:     label = "vpi";     break;
    case kDLROCM:    label = "rocm";    break;
    case kOpenGL:    label = "opengl";  break;
    case kExtDev:    label = "ext_dev"; break;
    default:         break;
  }
  if (label == nullptr) {
    // No known device matched the code.
    LOG(FATAL) << "unknown type =" << type;
    return "Unknown";
  }
  return label;
}

/*!
 * \brief Singleton cache of DeviceAPI pointers, one slot per device type plus
 *        a single shared slot for every type at or above kRPCSessMask.
 *
 * Instances are looked up lazily through the global function registry under
 * the key "device_api.<name>" and cached after the first lookup.
 */
class DeviceAPIManager {
 public:
  /*! \brief Number of per-device-type cache slots. */
  static const int kMaxDeviceAPI = 32;

  /*!
   * \brief Get the DeviceAPI for a context (missing APIs are a fatal error).
   * \param ctx The context whose device_type selects the API.
   */
  static DeviceAPI* Get(const DGLContext& ctx) {
    return Get(ctx.device_type);
  }

  /*!
   * \brief Get the DeviceAPI for a device type code.
   * \param dev_type The device type code.
   * \param allow_missing When true, return nullptr instead of failing if no
   *        factory is registered for the device.
   */
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  /*! \brief Cached per-device APIs, indexed by device type code. */
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  /*! \brief Cached API shared by all types >= kRPCSessMask. */
  DeviceAPI* rpc_api_{nullptr};
  /*! \brief Guards the lazy initialisation of api_ and rpc_api_. */
  std::mutex mutex_;

  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }

  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }

  /*!
   * \brief Get or lazily initialise the API for a device type code.
   *
   * The cache slot is tested once without the lock and once more with the
   * lock held before it is filled.
   * NOTE(review): the unlocked read of a plain pointer is not synchronised
   * with the locked write; consider std::atomic if this is a concern.
   */
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      // Reject codes that do not map to a slot before touching api_;
      // indexing std::array out of range is undefined behaviour.
      CHECK(type >= 0 && type < kMaxDeviceAPI)
          << "device type " << type << " is out of range [0, "
          << api_.size() << ")";
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }

  /*!
   * \brief Look up the factory "device_api.<name>" in the registry and call it.
   * \param name Short device name.
   * \param allow_missing When false, a missing factory fails the CHECK.
   * \return The DeviceAPI produced by the factory, or nullptr if the factory
   *         is missing and allow_missing is true.
   */
  DeviceAPI* GetAPI(const std::string& name, bool allow_missing) {
    const std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
      return nullptr;
    }
    // The factory's return value is converted to an opaque pointer.
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

97
/*!
 * \brief Fetch the DeviceAPI that serves the given context.
 * \param ctx The context; only its device_type is consulted here.
 * \param allow_missing Forwarded to DeviceAPIManager::Get.
 * \return Whatever DeviceAPIManager::Get returns for the device type.
 */
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}

102
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
103
                                size_t size,
104
                                DGLType type_hint) {
Minjie Wang's avatar
Minjie Wang committed
105
106
107
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

108
/*!
 * \brief Release a workspace buffer by delegating to FreeDataSpace.
 * \param ctx The context the buffer belongs to.
 * \param ptr The buffer to release.
 */
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

112
/*!
 * \brief Stream creation for this implementation: always a fatal error.
 *        Presumably overridden by backends that support streams.
 * \param ctx Unused.
 * \return A zero handle as the fallback return value after the fatal log.
 */
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  (void)ctx;
  LOG(FATAL) << "Device does not support stream api.";
  return DGLStreamHandle();
}

117
/*!
 * \brief Stream destruction for this implementation: always a fatal error.
 *        Presumably overridden by backends that support streams.
 * \param ctx Unused.
 * \param stream Unused.
 */
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  (void)ctx;
  (void)stream;
  LOG(FATAL) << "Device does not support stream api.";
}

121
122
123
/*!
 * \brief Stream-to-stream synchronisation for this implementation: always a
 *        fatal error. Presumably overridden by backends that support streams.
 * \param ctx Unused.
 * \param event_src Unused.
 * \param event_dst Unused.
 */
void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
                                 DGLStreamHandle event_src,
                                 DGLStreamHandle event_dst) {
  (void)ctx;
  (void)event_src;
  (void)event_dst;
  LOG(FATAL) << "Device does not support stream api.";
}
126
127
128
129
130
131
132
133

/*!
 * \brief Memory pinning for this implementation: always a fatal error.
 * \param ctx Unused.
 * \param ptr Unused.
 * \param nbytes Unused.
 */
void DeviceAPI::PinData(DGLContext ctx, void* ptr, size_t nbytes) {
  (void)ctx;
  (void)ptr;
  (void)nbytes;
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
}

/*!
 * \brief Memory unpinning for this implementation: always a fatal error.
 * \param ctx Unused.
 * \param ptr Unused.
 */
void DeviceAPI::UnpinData(DGLContext ctx, void* ptr) {
  (void)ctx;
  (void)ptr;
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
134
}  // namespace runtime
135
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
136

137
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
138

139
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
140
141
  std::string ret_str;
  std::string last_error;
142
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
143
144
};

145
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
Minjie Wang's avatar
Minjie Wang committed
146

147
148
const char *DGLGetLastError() {
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
149
150
}

151
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
152
#ifndef _LIBCPP_SGX_CONFIG
153
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
154
155
156
157
158
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

159
int DGLModLoadFromFile(const char* file_name,
Minjie Wang's avatar
Minjie Wang committed
160
                       const char* format,
161
                       DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
162
163
164
165
166
167
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

168
169
/*!
 * \brief Make module \p mod import module \p dep.
 * \param mod Handle treated as a Module*; receives the import.
 * \param dep Handle treated as a Module*; the module being imported.
 * \return Status produced by the API_BEGIN/API_END wrapper (see runtime_base.h).
 */
int DGLModImport(DGLModuleHandle mod,
                 DGLModuleHandle dep) {
  API_BEGIN();
  Module* importer = static_cast<Module*>(mod);
  Module* dependency = static_cast<Module*>(dep);
  importer->Import(*dependency);
  API_END();
}

176
int DGLModGetFunction(DGLModuleHandle mod,
Minjie Wang's avatar
Minjie Wang committed
177
178
                      const char* func_name,
                      int query_imports,
179
                      DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
180
181
182
183
184
185
186
187
188
189
190
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
      func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

191
/*!
 * \brief Destroy a module handle.
 * \param mod Handle treated as a Module* and deleted.
 * \return Status produced by the API_BEGIN/API_END wrapper (see runtime_base.h).
 */
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* module = static_cast<Module*>(mod);
  delete module;
  API_END();
}

197
int DGLBackendGetFuncFromEnv(void* mod_node,
Minjie Wang's avatar
Minjie Wang committed
198
                             const char* func_name,
199
                             DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
200
  API_BEGIN();
201
  *func = (DGLFunctionHandle)(
Minjie Wang's avatar
Minjie Wang committed
202
203
204
205
      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

206
void* DGLBackendAllocWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
207
208
209
210
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
211
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
212
213
214
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;

215
  DGLType type_hint;
Minjie Wang's avatar
Minjie Wang committed
216
217
218
219
220
221
222
223
224
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

225
int DGLBackendFreeWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
226
227
                            int device_id,
                            void* ptr) {
228
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
229
230
231
232
233
234
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

235
/*!
 * \brief Run \p f once per handle slot.
 *
 * If *handle is null it is set to a non-null marker and f(cdata) is called;
 * otherwise nothing is done.
 *
 * \param handle Slot remembering whether the function has already run.
 * \param f Function to run.
 * \param cdata Opaque argument passed to \p f.
 * \param nbytes Not used by this implementation.
 * \return The value returned by \p f on the first call, 0 on later calls,
 *         and -1 (without touching the last-error string) when \p handle or
 *         \p f is null.
 */
int DGLBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  (void)nbytes;
  // Guard against null arguments instead of dereferencing them.
  if (handle == nullptr || f == nullptr) {
    return -1;
  }
  if (*handle != nullptr) {
    return 0;  // already ran for this slot
  }
  // Mark the slot before invoking f, as the marker only needs to be non-null.
  *handle = reinterpret_cast<void*>(1);
  return (*f)(cdata);
}

246
/*!
 * \brief Destroy a function handle.
 * \param func Handle treated as a PackedFunc* and deleted.
 * \return Status produced by the API_BEGIN/API_END wrapper (see runtime_base.h).
 */
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* packed = static_cast<PackedFunc*>(func);
  delete packed;
  API_END();
}

252
253
int DGLFuncCall(DGLFunctionHandle func,
                DGLValue* args,
Minjie Wang's avatar
Minjie Wang committed
254
255
                int* arg_type_codes,
                int num_args,
256
                DGLValue* ret_val,
Minjie Wang's avatar
Minjie Wang committed
257
258
                int* ret_type_code) {
  API_BEGIN();
259
  DGLRetValue rv;
Minjie Wang's avatar
Minjie Wang committed
260
  (*static_cast<const PackedFunc*>(func)).CallPacked(
261
      DGLArgs(args, arg_type_codes, num_args), &rv);
Minjie Wang's avatar
Minjie Wang committed
262
263
  // handle return string.
  if (rv.type_code() == kStr ||
264
     rv.type_code() == kDGLType ||
Minjie Wang's avatar
Minjie Wang committed
265
      rv.type_code() == kBytes) {
266
267
    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
    if (rv.type_code() != kDGLType) {
Minjie Wang's avatar
Minjie Wang committed
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

287
288
/*!
 * \brief Store a return value into a return-value handle.
 * \param ret Handle treated as a DGLRetValue*.
 * \param value Array of values; only the first element is read.
 * \param type_code Array of type codes; only the first element is read.
 * \param num_ret Number of return values; checked to be exactly 1.
 * \return Status produced by the API_BEGIN/API_END wrapper (see runtime_base.h).
 */
int DGLCFuncSetReturn(DGLRetValueHandle ret,
                      DGLValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue* target = static_cast<DGLRetValue*>(ret);
  *target = DGLArgValue(*value, *type_code);
  API_END();
}

298
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
Minjie Wang's avatar
Minjie Wang committed
299
                           void* resource_handle,
300
301
                           DGLPackedCFuncFinalizer fin,
                           DGLFunctionHandle *out) {
Minjie Wang's avatar
Minjie Wang committed
302
303
304
  API_BEGIN();
  if (fin == nullptr) {
    *out = new PackedFunc(
305
306
        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
307
308
                         args.num_args, rv, resource_handle);
          if (ret != 0) {
309
310
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
311
312
313
314
315
316
317
318
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
319
320
        [func, rpack](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
321
322
                         args.num_args, rv, rpack.get());
          if (ret != 0) {
323
324
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
325
326
327
328
329
330
331
            throw dmlc::Error(err);
          }
      });
  }
  API_END();
}

332
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
333
  API_BEGIN();
334
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
335
336
337
338
339
340
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

341
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
342
  API_BEGIN();
343
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
344
345
346
347
348
349
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

350
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
351
  API_BEGIN();
352
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
353
354
355
356
357
358
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

359
360
361
362
363
364
365
366
367
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

368
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
369
  API_BEGIN();
370
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
371
372
373
374
375
376
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

377
int DGLStreamStreamSynchronize(int device_type,
Minjie Wang's avatar
Minjie Wang committed
378
                               int device_id,
379
380
                               DGLStreamHandle src,
                               DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
381
  API_BEGIN();
382
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
383
384
385
386
387
388
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

389
/*!
 * \brief Round-trip a callback argument through a DGLRetValue and write the
 *        result back into \p value.
 * \param value In/out value; overwritten by MoveToCHost.
 * \param code Type code of \p value; checked to equal the code reported by
 *        MoveToCHost.
 * \return Status produced by the API_BEGIN/API_END wrapper (see runtime_base.h).
 */
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  DGLRetValue holder;
  holder = DGLArgValue(*value, code);
  int moved_code;
  holder.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

399
400
/*!
 * \brief Ask the global TensorDispatcher to load the adapter at \p path.
 * \param path Path passed to TensorDispatcher::Load.
 * \return 0 when Load reports success, -1 otherwise.
 */
int DGLLoadTensorAdapter(const char *path) {
  if (TensorDispatcher::Global()->Load(path)) {
    return 0;
  }
  return -1;
}

Minjie Wang's avatar
Minjie Wang committed
403
// set device api
404
405
406
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args[0]: device type code, args[1]: device id. Nothing is written to ret.
    const int dev_type = args[0].operator int();
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(dev_type);
    ctx.device_id = args[1];
    DeviceAPI* api = DeviceAPIManager::Get(ctx);
    api->SetDevice(ctx);
  });

// get device attribute
413
414
415
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args[0]: device type code, args[1]: device id, args[2]: attribute kind.
    const int dev_type = args[0].operator int();
    const int kind_code = args[2].operator int();

    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(dev_type);
    ctx.device_id = args[1];
    const DeviceAttrKind kind = static_cast<DeviceAttrKind>(kind_code);

    if (kind != kExist) {
      // Any other attribute requires the device API to be present.
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    // Existence query: a missing device API is reported as 0 rather than
    // treated as an error.
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
      return;
    }
    api->GetAttr(ctx, kind, ret);
  });
431