c_runtime_api.cc 11.5 KB
Newer Older
1
/**
 *  Copyright (c) 2016-2022 by Contributors
 * @file c_runtime_api.cc
 * @brief Runtime API implementation
 */
#include <dgl/runtime/c_backend_api.h>
7
8
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
Minjie Wang's avatar
Minjie Wang committed
9
#include <dgl/runtime/module.h>
10
#include <dgl/runtime/packed_func.h>
Minjie Wang's avatar
Minjie Wang committed
11
#include <dgl/runtime/registry.h>
12
#include <dgl/runtime/tensordispatch.h>
13
14
#include <dmlc/thread_local.h>

Minjie Wang's avatar
Minjie Wang committed
15
#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <mutex>
#include <string>

Minjie Wang's avatar
Minjie Wang committed
20
21
#include "runtime_base.h"

22
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
23
24
namespace runtime {

25
/**
26
27
 * @brief The name of Device API factory.
 * @param type The device type.
Minjie Wang's avatar
Minjie Wang committed
28
29
30
 */
inline std::string DeviceName(int type) {
  switch (type) {
31
32
33
34
    case kDGLCPU:
      return "cpu";
    case kDGLCUDA:
      return "cuda";
35
    // add more device here once supported
36
37
38
    default:
      LOG(FATAL) << "unknown type =" << type;
      return "Unknown";
Minjie Wang's avatar
Minjie Wang committed
39
40
41
42
43
44
45
  }
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
46
  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
Minjie Wang's avatar
Minjie Wang committed
47
48
49
50
51
52
53
54
55
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
56
  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
Minjie Wang's avatar
Minjie Wang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
83
84
          << "Device API " << name
          << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
85
86
87
88
89
90
91
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

92
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  // Lookup is keyed on the device type only; ctx.device_id is not consulted.
  const int device_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(device_type, allow_missing);
}

97
DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  // Forward to the manager using the integral device type code.
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

101
102
void* DeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  // Workspace memory is plain data space aligned to kTempAllocaAlignment.
  return this->AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

106
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  // Counterpart of AllocWorkspace: release the underlying data space.
  this->FreeDataSpace(ctx, ptr);
}

110
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  // Default DeviceAPI behaviour: streams are unsupported, so this always
  // fails via LOG(FATAL). The return only satisfies the signature.
  LOG(FATAL) << "Device does not support stream api.";
  return DGLStreamHandle{};
}

115
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  // Default DeviceAPI behaviour: streams are unsupported.
  LOG(FATAL) << "Device does not support stream api.";
}

119
120
void DeviceAPI::SyncStreamFromTo(
    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
  // Default DeviceAPI behaviour: cross-stream synchronization is unsupported.
  LOG(FATAL) << "Device does not support stream api.";
}
123

124
bool DeviceAPI::PinData(void* ptr, size_t nbytes) {
125
  LOG(FATAL) << "Device does not support cudaHostRegister api.";
126
  return false;
127
128
}

129
void DeviceAPI::UnpinData(void* ptr) {
130
131
  LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
Minjie Wang's avatar
Minjie Wang committed
132
}  // namespace runtime
133
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
134

135
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
136

137
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
138
139
  std::string ret_str;
  std::string last_error;
140
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
141
142
};

143
// Thread-local holder of the DGLRuntimeEntry used by the C API functions.
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;
Minjie Wang's avatar
Minjie Wang committed
144

145
const char* DGLGetLastError() {
146
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
147
148
}

149
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
150
#ifndef _LIBCPP_SGX_CONFIG
151
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
152
153
154
155
156
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

157
158
int DGLModLoadFromFile(
    const char* file_name, const char* format, DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
159
160
161
162
163
164
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

165
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
  API_BEGIN();
  // Make `dep` an import of `mod`.
  Module* self = static_cast<Module*>(mod);
  Module* dependency = static_cast<Module*>(dep);
  self->Import(*dependency);
  API_END();
}

171
172
173
int DGLModGetFunction(
    DGLModuleHandle mod, const char* func_name, int query_imports,
    DGLFunctionHandle* func) {
  API_BEGIN();
  const bool search_imports = query_imports != 0;
  Module* module = static_cast<Module*>(mod);
  PackedFunc pf = module->GetFunction(func_name, search_imports);
  // A missing function is reported as a null handle, not as an error.
  *func = (pf != nullptr) ? new PackedFunc(pf) : nullptr;
  API_END();
}

185
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  // Release a Module previously handed out as a DGLModuleHandle.
  Module* module = static_cast<Module*>(mod);
  delete module;
  API_END();
}

191
192
int DGLBackendGetFuncFromEnv(
    void* mod_node, const char* func_name, DGLFunctionHandle* func) {
Minjie Wang's avatar
Minjie Wang committed
193
  API_BEGIN();
194
195
196
  *func =
      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
          func_name));
Minjie Wang's avatar
Minjie Wang committed
197
198
199
  API_END();
}

200
201
202
void* DGLBackendAllocWorkspace(
    int device_type, int device_id, uint64_t size, int dtype_code_hint,
    int dtype_bits_hint) {
203
  DGLContext ctx;
204
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
205
206
  ctx.device_id = device_id;

207
  DGLDataType type_hint;
Minjie Wang's avatar
Minjie Wang committed
208
209
210
211
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

212
213
  return DeviceAPIManager::Get(ctx)->AllocWorkspace(
      ctx, static_cast<size_t>(size), type_hint);
Minjie Wang's avatar
Minjie Wang committed
214
215
}

216
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
217
  DGLContext ctx;
218
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
219
220
221
222
223
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

224
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  // A non-null handle means f has already been run through this handle.
  if (*handle != nullptr) return 0;
  // Mark the handle before invoking f; nbytes is not used.
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

232
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  // Release a PackedFunc previously handed out as a DGLFunctionHandle.
  PackedFunc* packed = static_cast<PackedFunc*>(func);
  delete packed;
  API_END();
}

238
239
240
int DGLFuncCall(
    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
    DGLValue* ret_val, int* ret_type_code) {
  API_BEGIN();
  const PackedFunc& packed = *static_cast<const PackedFunc*>(func);
  DGLRetValue rv;
  packed.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);

  const int code = rv.type_code();
  const bool string_like =
      code == kStr || code == kDGLDataType || code == kBytes;
  if (!string_like) {
    // Non-string results are moved straight into the C return slot.
    rv.MoveToCHost(ret_val, ret_type_code);
  } else {
    // String-like results are copied into this thread's DGLRuntimeEntry so
    // the pointer handed to C stays valid after rv is destroyed; it is
    // overwritten by the next string-like result on the same thread.
    DGLRuntimeEntry* entry = DGLAPIRuntimeStore::Get();
    if (code == kDGLDataType) {
      entry->ret_str = rv.operator std::string();
    } else {
      entry->ret_str = *rv.ptr<std::string>();
    }
    if (code != kBytes) {
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    } else {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    }
  }
  API_END();
}

269
270
int DGLCFuncSetReturn(
    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
  API_BEGIN();
  // Only a single return value is supported.
  CHECK_EQ(num_ret, 1);
  auto* ret_value = static_cast<DGLRetValue*>(ret);
  *ret_value = DGLArgValue(value[0], type_code[0]);
  API_END();
}

278
279
280
int DGLFuncCreateFromCFunc(
    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
    DGLFunctionHandle* out) {
  API_BEGIN();
  // Call the C callback; a non-zero status becomes a dmlc::Error carrying the
  // callback's last error message.
  auto invoke = [func](DGLArgs args, DGLRetValue* rv, void* handle) {
    const int status = func(
        const_cast<DGLValue*>(args.values), const_cast<int*>(args.type_codes),
        args.num_args, rv, handle);
    if (status != 0) {
      std::string err = "DGLCall CFunc Error:\n";
      err += DGLGetLastError();
      throw dmlc::Error(err);
    }
  };

  if (fin == nullptr) {
    *out = new PackedFunc(
        [invoke, resource_handle](DGLArgs args, DGLRetValue* rv) {
          invoke(args, rv, resource_handle);
        });
  } else {
    // The shared_ptr runs fin on resource_handle once the last copy of the
    // closure holding it is destroyed.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc([invoke, rpack](DGLArgs args, DGLRetValue* rv) {
      invoke(args, rv, rpack.get());
    });
  }
  API_END();
}

312
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
313
  API_BEGIN();
314
  DGLContext ctx;
315
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
316
317
318
319
320
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

321
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
322
  API_BEGIN();
323
  DGLContext ctx;
324
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
325
326
327
328
329
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

330
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
331
  API_BEGIN();
332
  DGLContext ctx;
333
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
334
335
336
337
338
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

339
340
341
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
342
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
343
344
345
346
347
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

348
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
349
  API_BEGIN();
350
  DGLContext ctx;
351
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
352
353
354
355
356
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

357
358
int DGLStreamStreamSynchronize(
    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
359
  API_BEGIN();
360
  DGLContext ctx;
361
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
362
363
364
365
366
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

367
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  // Round-trip *value through a DGLRetValue and move the result back into
  // *value in place; the type code must come back unchanged.
  dgl::runtime::DGLRetValue ret_value;
  ret_value = dgl::runtime::DGLArgValue(*value, code);
  int moved_code = -1;
  ret_value.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

377
int DGLLoadTensorAdapter(const char* path) {
378
  return TensorDispatcher::Global()->Load(path) ? 0 : -1;
379
380
}

Minjie Wang's avatar
Minjie Wang committed
381
// set device api
382
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
383
384
385
386
387
388
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
      ctx.device_id = args[1];
      DeviceAPIManager::Get(ctx)->SetDevice(ctx);
    });
Minjie Wang's avatar
Minjie Wang committed
389
390

// set device api
391
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
392
393
394
395
396
397
398
399
400
401
402
403
404
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      DGLContext ctx;
      ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
      ctx.device_id = args[1];

      DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
      if (kind == kExist) {
        DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
        if (api != nullptr) {
          api->GetAttr(ctx, kind, ret);
        } else {
          *ret = 0;
        }
Minjie Wang's avatar
Minjie Wang committed
405
      } else {
406
        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
Minjie Wang's avatar
Minjie Wang committed
407
      }
408
    });