c_runtime_api.cc 11.5 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
/*!
2
 *  Copyright (c) 2016-2022 by Contributors
3
4
 * @file c_runtime_api.cc
 * @brief Runtime API implementation
Minjie Wang's avatar
Minjie Wang committed
5
6
 */
#include <dgl/runtime/c_backend_api.h>
7
8
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
Minjie Wang's avatar
Minjie Wang committed
9
#include <dgl/runtime/module.h>
10
#include <dgl/runtime/packed_func.h>
Minjie Wang's avatar
Minjie Wang committed
11
#include <dgl/runtime/registry.h>
12
#include <dgl/runtime/tensordispatch.h>
13
14
#include <dmlc/thread_local.h>

Minjie Wang's avatar
Minjie Wang committed
15
#include <algorithm>
16
#include <array>
Minjie Wang's avatar
Minjie Wang committed
17
#include <cstdlib>
18
19
#include <string>

Minjie Wang's avatar
Minjie Wang committed
20
21
#include "runtime_base.h"

22
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
23
24
25
namespace runtime {

/*!
26
27
 * @brief The name of Device API factory.
 * @param type The device type.
Minjie Wang's avatar
Minjie Wang committed
28
29
30
 */
inline std::string DeviceName(int type) {
  switch (type) {
31
32
33
34
    case kDGLCPU:
      return "cpu";
    case kDGLCUDA:
      return "cuda";
35
    // add more device here once supported
36
37
38
    default:
      LOG(FATAL) << "unknown type =" << type;
      return "Unknown";
Minjie Wang's avatar
Minjie Wang committed
39
40
41
42
43
44
45
  }
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
46
  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
Minjie Wang's avatar
Minjie Wang committed
47
48
49
50
51
52
53
54
55
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
56
  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
Minjie Wang's avatar
Minjie Wang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
83
84
          << "Device API " << name
          << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
85
86
87
88
89
90
91
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

92
/*!
 * @brief Fetch the device API matching the device type of a context.
 * @param ctx The context whose device_type is used for the lookup.
 * @param allow_missing Forwarded to DeviceAPIManager::Get.
 * @return Whatever DeviceAPIManager::Get returns for that device type.
 */
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}

97
/*!
 * @brief Fetch the device API for a device type.
 * @param dev_type The device type.
 * @param allow_missing Forwarded to DeviceAPIManager::Get.
 * @return Whatever DeviceAPIManager::Get returns for that device type.
 */
DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

101
102
/*!
 * @brief Default workspace allocation: delegates to AllocDataSpace using
 *        kTempAllocaAlignment as the alignment.
 * @param ctx The context to allocate on.
 * @param size The size passed through to AllocDataSpace.
 * @param type_hint The data type hint passed through to AllocDataSpace.
 * @return The pointer returned by AllocDataSpace.
 */
void* DeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  void* const workspace =
      AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
  return workspace;
}

106
/*!
 * @brief Default workspace release: delegates to FreeDataSpace.
 * @param ctx The context the workspace belongs to.
 * @param ptr The workspace pointer to release.
 */
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

110
/*!
 * @brief Default implementation: logs a fatal error.
 * @param ctx Unused by this default implementation.
 * @return A zero-valued handle after the fatal log statement.
 */
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return DGLStreamHandle();
}

115
/*!
 * @brief Default implementation: logs a fatal error.
 * @param ctx Unused by this default implementation.
 * @param stream Unused by this default implementation.
 */
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}

119
120
/*!
 * @brief Default implementation: logs a fatal error.
 * @param ctx Unused by this default implementation.
 * @param event_src Unused by this default implementation.
 * @param event_dst Unused by this default implementation.
 */
void DeviceAPI::SyncStreamFromTo(
    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
123

124
void DeviceAPI::PinData(void* ptr, size_t nbytes) {
125
126
127
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
}

128
void DeviceAPI::UnpinData(void* ptr) {
129
130
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
131
}  // namespace runtime
132
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
133

134
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
135

136
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
137
138
  std::string ret_str;
  std::string last_error;
139
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
140
141
};

142
typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
Minjie Wang's avatar
Minjie Wang committed
143

144
const char* DGLGetLastError() {
145
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
146
147
}

148
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
149
#ifndef _LIBCPP_SGX_CONFIG
150
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
151
152
153
154
155
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

156
157
int DGLModLoadFromFile(
    const char* file_name, const char* format, DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
158
159
160
161
162
163
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

164
/*!
 * @brief Import module `dep` into module `mod`.
 * @param mod Handle pointing to the importing Module.
 * @param dep Handle pointing to the Module being imported.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
  API_BEGIN();
  Module* importer = static_cast<Module*>(mod);
  Module* dependency = static_cast<Module*>(dep);
  importer->Import(*dependency);
  API_END();
}

170
171
172
/*!
 * @brief Look up a function in a module.
 * @param mod Handle pointing to the Module to query.
 * @param func_name Name passed to Module::GetFunction.
 * @param query_imports Non-zero is passed to GetFunction as `true`.
 * @param func Receives a heap-allocated PackedFunc copy, or nullptr when the
 *        lookup result compares equal to nullptr.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLModGetFunction(
    DGLModuleHandle mod, const char* func_name, int query_imports,
    DGLFunctionHandle* func) {
  API_BEGIN();
  Module* module = static_cast<Module*>(mod);
  const bool search_imports = (query_imports != 0);
  PackedFunc pf = module->GetFunction(func_name, search_imports);
  *func = (pf != nullptr) ? new PackedFunc(pf) : nullptr;
  API_END();
}

184
/*!
 * @brief Destroy the Module behind a module handle.
 * @param mod Handle pointing to a heap-allocated Module.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

190
191
int DGLBackendGetFuncFromEnv(
    void* mod_node, const char* func_name, DGLFunctionHandle* func) {
Minjie Wang's avatar
Minjie Wang committed
192
  API_BEGIN();
193
194
195
  *func =
      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
          func_name));
Minjie Wang's avatar
Minjie Wang committed
196
197
198
  API_END();
}

199
200
201
void* DGLBackendAllocWorkspace(
    int device_type, int device_id, uint64_t size, int dtype_code_hint,
    int dtype_bits_hint) {
202
  DGLContext ctx;
203
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
204
205
  ctx.device_id = device_id;

206
  DGLDataType type_hint;
Minjie Wang's avatar
Minjie Wang committed
207
208
209
210
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

211
212
  return DeviceAPIManager::Get(ctx)->AllocWorkspace(
      ctx, static_cast<size_t>(size), type_hint);
Minjie Wang's avatar
Minjie Wang committed
213
214
}

215
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
216
  DGLContext ctx;
217
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
218
219
220
221
222
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

223
/*!
 * @brief Run `f(cdata)` only if `*handle` is still null.
 * @param handle In/out marker. When it points to nullptr it is set to a
 *        non-null sentinel before `f` is invoked.
 * @param f The callback to run.
 * @param cdata Argument passed to `f`.
 * @param nbytes Not used by this function.
 * @return The result of `f` when it was run, 0 otherwise.
 */
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  // Already marked as run: nothing to do.
  if (*handle != nullptr) return 0;
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

231
/*!
 * @brief Destroy the PackedFunc behind a function handle.
 * @param func Handle pointing to a heap-allocated PackedFunc.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

237
238
239
/*!
 * @brief Invoke a packed function and hand its result back through C values.
 *
 * Results whose type code is kStr, kDGLDataType or kBytes are copied into the
 * runtime entry's ret_str, and the returned pointer refers to that storage.
 * Every other result is moved out with DGLRetValue::MoveToCHost.
 *
 * @param func Handle pointing to the PackedFunc to call.
 * @param args Argument values.
 * @param arg_type_codes Type codes of the arguments.
 * @param num_args Number of arguments.
 * @param ret_val Receives the result value.
 * @param ret_type_code Receives the result type code (kDGLDataType results are
 *        reported as kStr).
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLFuncCall(
    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
    DGLValue* ret_val, int* ret_type_code) {
  API_BEGIN();
  DGLRetValue rv;
  const PackedFunc& callee = *static_cast<const PackedFunc*>(func);
  callee.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);

  const int tcode = rv.type_code();
  const bool needs_string_storage =
      (tcode == kStr) || (tcode == kDGLDataType) || (tcode == kBytes);
  if (!needs_string_storage) {
    rv.MoveToCHost(ret_val, ret_type_code);
  } else {
    // handle return string.
    DGLRuntimeEntry* entry = DGLAPIRuntimeStore::Get();
    if (tcode == kDGLDataType) {
      entry->ret_str = rv.operator std::string();
    } else {
      entry->ret_str = *rv.ptr<std::string>();
    }
    if (tcode == kBytes) {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    }
  }
  API_END();
}

268
269
/*!
 * @brief Store a single return value into a DGLRetValue from a C callback.
 * @param ret Handle pointing to the DGLRetValue to assign.
 * @param value Array of values; only element 0 is read.
 * @param type_code Array of type codes; only element 0 is read.
 * @param num_ret Number of return values; checked to be exactly 1.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLCFuncSetReturn(
    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue& target = *static_cast<DGLRetValue*>(ret);
  target = DGLArgValue(*value, *type_code);
  API_END();
}

277
278
279
/*!
 * @brief Wrap a C callback into a newly allocated PackedFunc.
 *
 * The wrapper calls `func` with the packed arguments and the resource handle.
 * A non-zero return from `func` is turned into a dmlc::Error whose message is
 * "DGLCall CFunc Error:\n" followed by DGLGetLastError().
 *
 * @param func The C callback.
 * @param resource_handle Opaque pointer passed to every call of `func`.
 * @param fin Optional finalizer. When non-null, resource_handle is held in a
 *        std::shared_ptr with `fin` as its deleter.
 * @param out Receives the heap-allocated PackedFunc.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLFuncCreateFromCFunc(
    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
    DGLFunctionHandle* out) {
  API_BEGIN();
  // Shared by both branches: call the C function and raise on failure.
  auto invoke = [func](DGLArgs args, DGLRetValue* rv, void* handle) {
    const int status = func(
        (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
        args.num_args, rv, handle);
    if (status != 0) {
      std::string err = "DGLCall CFunc Error:\n";
      err += DGLGetLastError();
      throw dmlc::Error(err);
    }
  };
  if (fin != nullptr) {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc([invoke, rpack](DGLArgs args, DGLRetValue* rv) {
      invoke(args, rv, rpack.get());
    });
  } else {
    *out = new PackedFunc(
        [invoke, resource_handle](DGLArgs args, DGLRetValue* rv) {
          invoke(args, rv, resource_handle);
        });
  }
  API_END();
}

311
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
312
  API_BEGIN();
313
  DGLContext ctx;
314
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
315
316
317
318
319
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

320
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
321
  API_BEGIN();
322
  DGLContext ctx;
323
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
324
325
326
327
328
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

329
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
330
  API_BEGIN();
331
  DGLContext ctx;
332
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
333
334
335
336
337
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

338
339
340
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
341
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
342
343
344
345
346
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

347
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
348
  API_BEGIN();
349
  DGLContext ctx;
350
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
351
352
353
354
355
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

356
357
int DGLStreamStreamSynchronize(
    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
358
  API_BEGIN();
359
  DGLContext ctx;
360
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
361
362
363
364
365
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

366
/*!
 * @brief Round-trip a callback argument through a DGLRetValue.
 *
 * The value is assigned into a DGLRetValue as a DGLArgValue and then moved
 * back into `value` with MoveToCHost; the resulting type code is checked to
 * equal `code`.
 *
 * @param value In/out value.
 * @param code Type code of `value`.
 * @return The status produced by the API_BEGIN/API_END wrapper.
 */
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  DGLRetValue rv;
  rv = DGLArgValue(*value, code);
  int moved_code;
  rv.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

376
int DGLLoadTensorAdapter(const char* path) {
377
  return TensorDispatcher::Global()->Load(path) ? 0 : -1;
378
379
}

Minjie Wang's avatar
Minjie Wang committed
380
// set device api
381
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
382
383
384
385
386
387
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
      ctx.device_id = args[1];
      DeviceAPIManager::Get(ctx)->SetDevice(ctx);
    });
Minjie Wang's avatar
Minjie Wang committed
388
389

// set device api
390
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
391
392
393
394
395
396
397
398
399
400
401
402
403
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
      ctx.device_id = args[1];

      DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
      if (kind == kExist) {
        DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
        if (api != nullptr) {
          api->GetAttr(ctx, kind, ret);
        } else {
          *ret = 0;
        }
Minjie Wang's avatar
Minjie Wang committed
404
      } else {
405
        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
Minjie Wang's avatar
Minjie Wang committed
406
      }
407
    });