/*!
 *  Copyright (c) 2016-2022 by Contributors
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>

#include "runtime_base.h"

20
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
21
22
23
24
25
26
27
28
namespace runtime {

/*!
 * \brief The name of Device API factory.
 * \param type The device type.
 */
inline std::string DeviceName(int type) {
  switch (type) {
29
30
31
    case kDGLCPU: return "cpu";
    case kDGLCUDA: return "cuda";
    // add more device here once supported
Minjie Wang's avatar
Minjie Wang committed
32
33
34
35
36
37
38
39
    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
  }
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
40
  static DeviceAPI* Get(const DGLContext& ctx) {
Minjie Wang's avatar
Minjie Wang committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
    return Get(ctx.device_type);
  }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
81
          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
82
83
84
85
86
87
88
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

89
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  // Only the device type selects the API; ctx.device_id is not consulted.
  const int type_code = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

94
DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  // Forward to the manager using the integral device type code.
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

98
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
99
                                size_t size,
100
                                DGLDataType type_hint) {
Minjie Wang's avatar
Minjie Wang committed
101
102
103
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

104
/*!
 * \brief Default workspace release: hands the pointer back to FreeDataSpace.
 */
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

108
/*!
 * \brief Default for devices without stream support: reports a fatal error.
 * \return A null handle, reached only if LOG(FATAL) returns.
 */
DGLStreamHandle DeviceAPI::CreateStream(DGLContext /*ctx*/) {
  LOG(FATAL) << "Device does not support stream api.";
  return DGLStreamHandle{};
}

113
/*!
 * \brief Default for devices without stream support: reports a fatal error.
 */
void DeviceAPI::FreeStream(DGLContext /*ctx*/, DGLStreamHandle /*stream*/) {
  LOG(FATAL) << "Device does not support stream api.";
}

117
118
119
/*!
 * \brief Default for devices without stream support: reports a fatal error.
 */
void DeviceAPI::SyncStreamFromTo(DGLContext /*ctx*/,
                                 DGLStreamHandle /*event_src*/,
                                 DGLStreamHandle /*event_dst*/) {
  LOG(FATAL) << "Device does not support stream api.";
}
122

123
void DeviceAPI::PinData(void* ptr, size_t nbytes) {
124
125
126
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
}

127
void DeviceAPI::UnpinData(void* ptr) {
128
129
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
130
}  // namespace runtime
131
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
132

133
// The C API entry points below are defined at global scope but use many
// dgl::runtime types and helpers, so bring that namespace into scope here.
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
134

135
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
136
137
  std::string ret_str;
  std::string last_error;
138
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
139
140
};

141
// Per-thread instance of DGLRuntimeEntry used by the C API functions below.
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;
Minjie Wang's avatar
Minjie Wang committed
142

143
144
const char *DGLGetLastError() {
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
145
146
}

147
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
148
#ifndef _LIBCPP_SGX_CONFIG
149
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
150
151
152
153
154
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

155
int DGLModLoadFromFile(const char* file_name,
Minjie Wang's avatar
Minjie Wang committed
156
                       const char* format,
157
                       DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
158
159
160
161
162
163
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

164
165
/*!
 * \brief Add `dep` to the import list of `mod`.
 * \param mod Handle of the importing module.
 * \param dep Handle of the module being imported.
 */
int DGLModImport(DGLModuleHandle mod,
                 DGLModuleHandle dep) {
  API_BEGIN();
  Module* target = static_cast<Module*>(mod);
  Module& dependency = *static_cast<Module*>(dep);
  target->Import(dependency);
  API_END();
}

172
int DGLModGetFunction(DGLModuleHandle mod,
Minjie Wang's avatar
Minjie Wang committed
173
174
                      const char* func_name,
                      int query_imports,
175
                      DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
176
177
178
179
180
181
182
183
184
185
186
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
      func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

187
/*!
 * \brief Destroy a module handle created by this API.
 */
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

193
int DGLBackendGetFuncFromEnv(void* mod_node,
Minjie Wang's avatar
Minjie Wang committed
194
                             const char* func_name,
195
                             DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
196
  API_BEGIN();
197
  *func = (DGLFunctionHandle)(
Minjie Wang's avatar
Minjie Wang committed
198
199
200
201
      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

202
void* DGLBackendAllocWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
203
204
205
206
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
207
  DGLContext ctx;
208
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
209
210
  ctx.device_id = device_id;

211
  DGLDataType type_hint;
Minjie Wang's avatar
Minjie Wang committed
212
213
214
215
216
217
218
219
220
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

221
int DGLBackendFreeWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
222
223
                            int device_id,
                            void* ptr) {
224
  DGLContext ctx;
225
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
226
227
228
229
230
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

231
/*!
 * \brief Run `f(cdata)` only the first time this is called with `*handle`
 *        null.
 * \param handle Flag slot; a non-null value means "already run".
 * \param f Callback to run once.
 * \param cdata Argument passed to f.
 * \param nbytes Unused.
 * \return f's result on the first call, 0 otherwise.
 *
 * Not synchronized: concurrent first calls may each run f.
 */
int DGLBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  // Mark the slot before calling, so it is set even if f reports failure.
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

242
/*!
 * \brief Destroy a PackedFunc handle created by this API.
 */
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

248
249
int DGLFuncCall(DGLFunctionHandle func,
                DGLValue* args,
Minjie Wang's avatar
Minjie Wang committed
250
251
                int* arg_type_codes,
                int num_args,
252
                DGLValue* ret_val,
Minjie Wang's avatar
Minjie Wang committed
253
254
                int* ret_type_code) {
  API_BEGIN();
255
  DGLRetValue rv;
Minjie Wang's avatar
Minjie Wang committed
256
  (*static_cast<const PackedFunc*>(func)).CallPacked(
257
      DGLArgs(args, arg_type_codes, num_args), &rv);
Minjie Wang's avatar
Minjie Wang committed
258
259
  // handle return string.
  if (rv.type_code() == kStr ||
260
     rv.type_code() == kDGLDataType ||
Minjie Wang's avatar
Minjie Wang committed
261
      rv.type_code() == kBytes) {
262
    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
263
    if (rv.type_code() != kDGLDataType) {
Minjie Wang's avatar
Minjie Wang committed
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

283
284
/*!
 * \brief Store a C-side value as the return value of a packed call.
 * \param ret Handle to the DGLRetValue to fill.
 * \param value Array holding the value(s).
 * \param type_code Type codes matching `value`.
 * \param num_ret Number of return values; must be 1.
 */
int DGLCFuncSetReturn(DGLRetValueHandle ret,
                      DGLValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue& out = *static_cast<DGLRetValue*>(ret);
  out = DGLArgValue(value[0], type_code[0]);
  API_END();
}

294
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
Minjie Wang's avatar
Minjie Wang committed
295
                           void* resource_handle,
296
297
                           DGLPackedCFuncFinalizer fin,
                           DGLFunctionHandle *out) {
Minjie Wang's avatar
Minjie Wang committed
298
299
300
  API_BEGIN();
  if (fin == nullptr) {
    *out = new PackedFunc(
301
302
        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
303
304
                         args.num_args, rv, resource_handle);
          if (ret != 0) {
305
306
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
307
308
309
310
311
312
313
314
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
315
316
        [func, rpack](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
317
318
                         args.num_args, rv, rpack.get());
          if (ret != 0) {
319
320
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
321
322
323
324
325
326
327
            throw dmlc::Error(err);
          }
      });
  }
  API_END();
}

328
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
329
  API_BEGIN();
330
  DGLContext ctx;
331
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
332
333
334
335
336
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

337
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
338
  API_BEGIN();
339
  DGLContext ctx;
340
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
341
342
343
344
345
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

346
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
347
  API_BEGIN();
348
  DGLContext ctx;
349
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
350
351
352
353
354
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

355
356
357
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
358
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
359
360
361
362
363
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

364
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
365
  API_BEGIN();
366
  DGLContext ctx;
367
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
368
369
370
371
372
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

373
int DGLStreamStreamSynchronize(int device_type,
Minjie Wang's avatar
Minjie Wang committed
374
                               int device_id,
375
376
                               DGLStreamHandle src,
                               DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
377
  API_BEGIN();
378
  DGLContext ctx;
379
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
380
381
382
383
384
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

385
/*!
 * \brief Rewrite `*value` in place by round-tripping it through a
 *        DGLRetValue (argument form -> return-value form).
 * \param value The value to convert; overwritten with the converted value.
 * \param code Type code of `*value`; the conversion must preserve it.
 */
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  dgl::runtime::DGLRetValue holder;
  holder = dgl::runtime::DGLArgValue(*value, code);
  int moved_code;
  holder.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

395
396
/*!
 * \brief Load the tensor adapter library at `path` into TensorDispatcher.
 * \return 0 on success, -1 if loading failed.
 */
int DGLLoadTensorAdapter(const char *path) {
  if (TensorDispatcher::Global()->Load(path)) {
    return 0;
  }
  return -1;
}

Minjie Wang's avatar
Minjie Wang committed
399
// set device api
400
401
402
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args: (device_type, device_id)
    const int type_code = args[0].operator int();
    DGLContext ctx;
    ctx.device_type = static_cast<DGLDeviceType>(type_code);
    ctx.device_id = args[1];
    DeviceAPI* api = DeviceAPIManager::Get(ctx);
    api->SetDevice(ctx);
  });

// get device attribute api
409
410
411
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args: (device_type, device_id, attribute kind)
    DGLContext ctx;
    ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    const DeviceAttrKind kind =
        static_cast<DeviceAttrKind>(args[2].operator int());

    if (kind != kExist) {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    // For kExist, a missing device API is answered with 0 instead of failing.
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
      return;
    }
    api->GetAttr(ctx, kind, ret);
  });
427