"src/vscode:/vscode.git/clone" did not exist on "5bacc2f5af7904749827fac15df05df0cb782fe4"
c_runtime_api.cc 11.8 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include "runtime_base.h"

19
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
namespace runtime {

/*!
 * \brief The name of Device API factory.
 * \param type The device type.
 */
/*!
 * \brief Map a device type code to the suffix of its Device API factory
 *  name (the factory is registered as "device_api.<suffix>").
 * \param type The device type.
 * \return The factory suffix; unknown types are reported with LOG(FATAL).
 */
inline std::string DeviceName(int type) {
  const char* suffix = nullptr;
  switch (type) {
    case kDLCPU:     suffix = "cpu";     break;
    case kDLGPU:     suffix = "gpu";     break;
    case kDLOpenCL:  suffix = "opencl";  break;
    case kDLSDAccel: suffix = "sdaccel"; break;
    case kDLAOCL:    suffix = "aocl";    break;
    case kDLVulkan:  suffix = "vulkan";  break;
    case kDLMetal:   suffix = "metal";   break;
    case kDLVPI:     suffix = "vpi";     break;
    case kDLROCM:    suffix = "rocm";    break;
    case kOpenGL:    suffix = "opengl";  break;
    case kExtDev:    suffix = "ext_dev"; break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      suffix = "Unknown";
      break;
  }
  return suffix;
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
47
  static DeviceAPI* Get(const DGLContext& ctx) {
Minjie Wang's avatar
Minjie Wang committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    return Get(ctx.device_type);
  }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
88
          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
89
90
91
92
93
94
95
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

96
/*!
 * \brief Resolve the DeviceAPI for ctx through the process-wide manager.
 *  allow_missing is forwarded unchanged.
 */
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  const int dev_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(dev_type, allow_missing);
}

101
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
102
                                size_t size,
103
                                DGLType type_hint) {
Minjie Wang's avatar
Minjie Wang committed
104
105
106
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

107
/*!
 * \brief Default workspace release: frees ptr via FreeDataSpace.
 */
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

111
/*!
 * \brief Default CreateStream for devices without stream support:
 *  reports a fatal error via LOG(FATAL).
 */
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return nullptr;
}

116
/*!
 * \brief Default FreeStream for devices without stream support:
 *  reports a fatal error via LOG(FATAL).
 */
void DeviceAPI::FreeStream(DGLContext ctx,
                           DGLStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}

120
121
122
/*!
 * \brief Default SyncStreamFromTo for devices without stream support:
 *  reports a fatal error via LOG(FATAL).
 */
void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
                                 DGLStreamHandle event_src,
                                 DGLStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
}  // namespace runtime
126
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
127

128
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
129

130
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
131
132
  std::string ret_str;
  std::string last_error;
133
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
134
135
};

136
// Thread-local singleton holding each thread's DGLRuntimeEntry.
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;
Minjie Wang's avatar
Minjie Wang committed
137

138
139
const char *DGLGetLastError() {
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
140
141
}

142
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
143
#ifndef _LIBCPP_SGX_CONFIG
144
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
145
146
147
148
149
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

150
int DGLModLoadFromFile(const char* file_name,
Minjie Wang's avatar
Minjie Wang committed
151
                       const char* format,
152
                       DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
153
154
155
156
157
158
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

159
160
/*! \brief Add module dep to the imports of module mod. */
int DGLModImport(DGLModuleHandle mod,
                 DGLModuleHandle dep) {
  API_BEGIN();
  Module* target = static_cast<Module*>(mod);
  const Module& imported = *static_cast<Module*>(dep);
  target->Import(imported);
  API_END();
}

167
int DGLModGetFunction(DGLModuleHandle mod,
Minjie Wang's avatar
Minjie Wang committed
168
169
                      const char* func_name,
                      int query_imports,
170
                      DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
171
172
173
174
175
176
177
178
179
180
181
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
      func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

182
/*! \brief Release a module handle created by this API. */
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

188
int DGLBackendGetFuncFromEnv(void* mod_node,
Minjie Wang's avatar
Minjie Wang committed
189
                             const char* func_name,
190
                             DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
191
  API_BEGIN();
192
  *func = (DGLFunctionHandle)(
Minjie Wang's avatar
Minjie Wang committed
193
194
195
196
      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

197
void* DGLBackendAllocWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
198
199
200
201
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
202
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
203
204
205
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;

206
  DGLType type_hint;
Minjie Wang's avatar
Minjie Wang committed
207
208
209
210
211
212
213
214
215
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

216
int DGLBackendFreeWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
217
218
                            int device_id,
                            void* ptr) {
219
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
220
221
222
223
224
225
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

226
/*!
 * \brief Run f(cdata) the first time *handle is seen as null.
 *  *handle is marked (set to a non-null sentinel) before f runs.
 * \return f's result on the first call, 0 on subsequent calls.
 *  nbytes is unused.
 */
int DGLBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  const bool already_ran = (*handle != nullptr);
  if (already_ran) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

237
/*! \brief Release a PackedFunc handle created by this API. */
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

243
244
int DGLFuncCall(DGLFunctionHandle func,
                DGLValue* args,
Minjie Wang's avatar
Minjie Wang committed
245
246
                int* arg_type_codes,
                int num_args,
247
                DGLValue* ret_val,
Minjie Wang's avatar
Minjie Wang committed
248
249
                int* ret_type_code) {
  API_BEGIN();
250
  DGLRetValue rv;
Minjie Wang's avatar
Minjie Wang committed
251
  (*static_cast<const PackedFunc*>(func)).CallPacked(
252
      DGLArgs(args, arg_type_codes, num_args), &rv);
Minjie Wang's avatar
Minjie Wang committed
253
254
  // handle return string.
  if (rv.type_code() == kStr ||
255
     rv.type_code() == kDGLType ||
Minjie Wang's avatar
Minjie Wang committed
256
      rv.type_code() == kBytes) {
257
258
    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
    if (rv.type_code() != kDGLType) {
Minjie Wang's avatar
Minjie Wang committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

278
279
/*!
 * \brief Store a single C value into a DGLRetValue handle.
 *  num_ret must be exactly 1 (enforced with CHECK_EQ).
 */
int DGLCFuncSetReturn(DGLRetValueHandle ret,
                      DGLValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue& out = *static_cast<DGLRetValue*>(ret);
  out = DGLArgValue(*value, *type_code);
  API_END();
}

289
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
Minjie Wang's avatar
Minjie Wang committed
290
                           void* resource_handle,
291
292
                           DGLPackedCFuncFinalizer fin,
                           DGLFunctionHandle *out) {
Minjie Wang's avatar
Minjie Wang committed
293
294
295
  API_BEGIN();
  if (fin == nullptr) {
    *out = new PackedFunc(
296
297
        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
298
299
                         args.num_args, rv, resource_handle);
          if (ret != 0) {
300
301
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
302
303
304
305
306
307
308
309
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
310
311
        [func, rpack](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
312
313
                         args.num_args, rv, rpack.get());
          if (ret != 0) {
314
315
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
316
317
318
319
320
321
322
            throw dmlc::Error(err);
          }
      });
  }
  API_END();
}

323
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
324
  API_BEGIN();
325
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
326
327
328
329
330
331
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

332
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
333
  API_BEGIN();
334
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
335
336
337
338
339
340
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

341
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
342
  API_BEGIN();
343
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
344
345
346
347
348
349
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

350
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
351
  API_BEGIN();
352
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
353
354
355
356
357
358
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

359
int DGLStreamStreamSynchronize(int device_type,
Minjie Wang's avatar
Minjie Wang committed
360
                               int device_id,
361
362
                               DGLStreamHandle src,
                               DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
363
  API_BEGIN();
364
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
365
366
367
368
369
370
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

371
/*!
 * \brief Round-trip *value through a DGLRetValue and write the result back
 *  into *value. CHECK fails if the resulting type code differs from code.
 */
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  int moved_code;
  dgl::runtime::DGLRetValue holder;
  holder = dgl::runtime::DGLArgValue(*value, code);
  holder.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

// set device api
382
383
384
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args: (device_type, device_id)
    const int device_type = args[0].operator int();
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(device_type);
    ctx.device_id = args[1];
    DeviceAPI* api = DeviceAPIManager::Get(ctx);
    api->SetDevice(ctx);
  });

// get device attribute api
391
392
393
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
    // args: (device_type, device_id, attribute kind)
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    const DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());

    if (kind != kExist) {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    // kExist is probed with allow_missing so an unregistered device API
    // yields 0 instead of a failure.
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
      return;
    }
    api->GetAttr(ctx, kind, ret);
  });
409