/**
 *  Copyright (c) 2016-2022 by Contributors
 * @file c_runtime_api.cc
 * @brief Runtime API implementation
 */
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>

#include "runtime_base.h"

22
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
23
24
namespace runtime {

25
/**
26
27
 * @brief The name of Device API factory.
 * @param type The device type.
Minjie Wang's avatar
Minjie Wang committed
28
29
30
 */
inline std::string DeviceName(int type) {
  switch (type) {
31
32
33
34
    case kDGLCPU:
      return "cpu";
    case kDGLCUDA:
      return "cuda";
35
    // add more device here once supported
36
37
38
    default:
      LOG(FATAL) << "unknown type =" << type;
      return "Unknown";
Minjie Wang's avatar
Minjie Wang committed
39
40
41
42
43
44
45
  }
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
46
  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
Minjie Wang's avatar
Minjie Wang committed
47
48
49
50
51
52
53
54
55
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
56
  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
Minjie Wang's avatar
Minjie Wang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
83
84
          << "Device API " << name
          << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
85
86
87
88
89
90
91
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

92
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  // Only the device type selects the API; ctx.device_id is not consulted.
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}

97
DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

101
102
// Default workspace allocator: a workspace is an ordinary data-space
// allocation aligned to kTempAllocaAlignment.
void* DeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  const auto alignment = kTempAllocaAlignment;
  return AllocDataSpace(ctx, size, alignment, type_hint);
}

106
// Default workspace release: the default AllocWorkspace obtains memory from
// AllocDataSpace, so it is returned through FreeDataSpace.
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

110
// Default CreateStream: the base DeviceAPI has no stream support, so this
// reports a fatal error. The return statement is only reached if LOG(FATAL)
// returns.
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return nullptr;
}

115
// Default FreeStream: the base DeviceAPI has no stream support and reports a
// fatal error.
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}

119
120
// Default SyncStreamFromTo: the base DeviceAPI has no stream support and
// reports a fatal error.
void DeviceAPI::SyncStreamFromTo(
    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
123

124
bool DeviceAPI::PinData(void* ptr, size_t nbytes) {
125
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
126
  return false;
127
128
}

129
130
131
132
133
134
135
136
137
138
// Default AllocPinnedDataSpace: pinned allocation is unsupported on the base
// DeviceAPI; a fatal error is reported and nullptr returned if execution
// continues. The ctx/deleter out-parameters are left untouched.
void* DeviceAPI::AllocPinnedDataSpace(
    size_t nbytes, void** ctx, void** deleter) {
  LOG(FATAL) << "Device does not support cudaHostAlloc api.";
  return nullptr;
}

// Default FreePinnedDataSpace: pinned allocation is unsupported on the base
// DeviceAPI, so this reports a fatal error.
void DeviceAPI::FreePinnedDataSpace(void** deleter) {
  LOG(FATAL) << "Device does not support cudaHostFree api.";
}

139
void DeviceAPI::UnpinData(void* ptr) {
140
141
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
142
}  // namespace runtime
143
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
144

145
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
146

147
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
148
149
  std::string ret_str;
  std::string last_error;
150
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
151
152
};

153
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
Minjie Wang's avatar
Minjie Wang committed
154

155
const char* DGLGetLastError() {
156
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
157
158
}

159
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
160
#ifndef _LIBCPP_SGX_CONFIG
161
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
162
163
164
165
166
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

167
168
int DGLModLoadFromFile(
    const char* file_name, const char* format, DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
169
170
171
172
173
174
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

175
// Make the module behind dep an import of the module behind mod.
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
  API_BEGIN();
  Module* parent = static_cast<Module*>(mod);
  Module* child = static_cast<Module*>(dep);
  parent->Import(*child);
  API_END();
}

181
182
183
// Look up func_name in mod (and its imports when query_imports is non-zero).
// On success *func is a new handle owned by the caller (free with
// DGLFuncFree); a null PackedFunc is returned as a null handle.
int DGLModGetFunction(
    DGLModuleHandle mod, const char* func_name, int query_imports,
    DGLFunctionHandle* func) {
  API_BEGIN();
  Module* m = static_cast<Module*>(mod);
  const bool search_imports = query_imports != 0;
  PackedFunc pf = m->GetFunction(func_name, search_imports);
  *func = (pf == nullptr) ? nullptr : new PackedFunc(pf);
  API_END();
}

195
// Release a module handle created by this C API.
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* m = static_cast<Module*>(mod);
  delete m;
  API_END();
}

201
202
int DGLBackendGetFuncFromEnv(
    void* mod_node, const char* func_name, DGLFunctionHandle* func) {
Minjie Wang's avatar
Minjie Wang committed
203
  API_BEGIN();
204
205
206
  *func =
      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
          func_name));
Minjie Wang's avatar
Minjie Wang committed
207
208
209
  API_END();
}

210
211
212
void* DGLBackendAllocWorkspace(
    int device_type, int device_id, uint64_t size, int dtype_code_hint,
    int dtype_bits_hint) {
213
  DGLContext ctx;
214
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
215
216
  ctx.device_id = device_id;

217
  DGLDataType type_hint;
Minjie Wang's avatar
Minjie Wang committed
218
219
220
221
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

222
223
  return DeviceAPIManager::Get(ctx)->AllocWorkspace(
      ctx, static_cast<size_t>(size), type_hint);
Minjie Wang's avatar
Minjie Wang committed
224
225
}

226
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
227
  DGLContext ctx;
228
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
229
230
231
232
233
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

234
// Run f(cdata) only if *handle is still null, marking *handle with a non-null
// sentinel before invoking f. Returns f's result on that first run and 0
// afterwards. nbytes is unused. NOTE(review): no synchronization is done, so
// concurrent first calls could both run f.
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return (*f)(cdata);
}

242
// Release a function handle created by this C API.
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* pf = static_cast<PackedFunc*>(func);
  delete pf;
  API_END();
}

248
249
250
// Invoke a PackedFunc through the C ABI and translate its return value into
// (ret_val, ret_type_code).
int DGLFuncCall(
    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
    DGLValue* ret_val, int* ret_type_code) {
  API_BEGIN();
  const PackedFunc& pf = *static_cast<const PackedFunc*>(func);
  DGLRetValue rv;
  pf.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);

  const int code = rv.type_code();
  if (code != kStr && code != kDGLDataType && code != kBytes) {
    // Non-string results are moved straight into the caller's C-side slots.
    rv.MoveToCHost(ret_val, ret_type_code);
  } else {
    // rv is a local that dies on return, so string-like results are copied
    // into the calling thread's DGLRuntimeEntry and the caller receives
    // pointers into that storage.
    DGLRuntimeEntry* entry = DGLAPIRuntimeStore::Get();
    if (code == kDGLDataType) {
      entry->ret_str = rv.operator std::string();
    } else {
      entry->ret_str = *rv.ptr<std::string>();
    }
    if (code == kBytes) {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    } else {
      // kStr and kDGLDataType are both reported as kStr.
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    }
  }
  API_END();
}

279
280
// Store value[0] (tagged with type_code[0]) into the return slot ret. Exactly
// one return value is supported.
int DGLCFuncSetReturn(
    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue* slot = static_cast<DGLRetValue*>(ret);
  *slot = DGLArgValue(value[0], type_code[0]);
  API_END();
}

288
289
290
int DGLFuncCreateFromCFunc(
    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
    DGLFunctionHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
291
292
  API_BEGIN();
  if (fin == nullptr) {
293
294
295
296
297
    *out =
        new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func(
              (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
              args.num_args, rv, resource_handle);
Minjie Wang's avatar
Minjie Wang committed
298
          if (ret != 0) {
299
300
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
301
302
303
304
305
306
307
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
308
309
310
311
312
313
314
315
316
317
    *out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
      int ret = func(
          (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
          args.num_args, rv, rpack.get());
      if (ret != 0) {
        std::string err = "DGLCall CFunc Error:\n";
        err += DGLGetLastError();
        throw dmlc::Error(err);
      }
    });
Minjie Wang's avatar
Minjie Wang committed
318
319
320
321
  }
  API_END();
}

322
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
323
  API_BEGIN();
324
  DGLContext ctx;
325
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
326
327
328
329
330
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

331
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
332
  API_BEGIN();
333
  DGLContext ctx;
334
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
335
336
337
338
339
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

340
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
341
  API_BEGIN();
342
  DGLContext ctx;
343
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
344
345
346
347
348
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

349
350
351
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
352
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
353
354
355
356
357
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

358
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
359
  API_BEGIN();
360
  DGLContext ctx;
361
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
362
363
364
365
366
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

367
368
int DGLStreamStreamSynchronize(
    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
369
  API_BEGIN();
370
  DGLContext ctx;
371
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
372
373
374
375
376
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

377
// Convert *value in place: assign it (tagged with code) to a DGLRetValue,
// then move that back into the C-side slot. The resulting type code must
// equal code.
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  dgl::runtime::DGLRetValue rv;
  rv = dgl::runtime::DGLArgValue(*value, code);
  int new_code;
  rv.MoveToCHost(value, &new_code);
  CHECK_EQ(new_code, code);
  API_END();
}

387
int DGLLoadTensorAdapter(const char* path) {
388
  return TensorDispatcher::Global()->Load(path) ? 0 : -1;
389
390
}

Minjie Wang's avatar
Minjie Wang committed
391
// set device api
392
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      // args: (device_type, device_id)
      const int device_type = args[0].operator int();
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(device_type);
      ctx.device_id = args[1];
      DeviceAPI* api = DeviceAPIManager::Get(ctx);
      api->SetDevice(ctx);
    });
Minjie Wang's avatar
Minjie Wang committed
399
400

// get device attribute
401
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      // args: (device_type, device_id, attribute kind)
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
      ctx.device_id = args[1];
      const auto kind = static_cast<DeviceAttrKind>(args[2].operator int());

      if (kind != kExist) {
        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
        return;
      }
      // kExist tolerates a missing device API and reports 0 in that case.
      DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
      if (api == nullptr) {
        *ret = 0;
      } else {
        api->GetAttr(ctx, kind, ret);
      }
    });