"docs/source/api/vscode:/vscode.git/clone" did not exist on "742d79a79289e5b558d39c2540f6a1821f71a3dc"
c_runtime_api.cc 11.9 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
/*!
 *  Copyright (c) 2016 by Contributors
 * \file c_runtime_api.cc
 * \brief Runtime API implementation
 */
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/c_backend_api.h>
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/tensordispatch.h>
#include <algorithm>
#include <array>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include "runtime_base.h"

20
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
namespace runtime {

/*!
 * \brief The name of Device API factory.
 * \param type The device type.
 */
/*!
 * \brief Map a device type code to the name its Device API factory is registered under.
 * \param type The device type code.
 * \return The short device name, e.g. "cpu" or "gpu"; unknown codes are fatal.
 */
inline std::string DeviceName(int type) {
  const char* name = nullptr;
  switch (type) {
    case kDLCPU:     name = "cpu";     break;
    case kDLGPU:     name = "gpu";     break;
    case kDLOpenCL:  name = "opencl";  break;
    case kDLSDAccel: name = "sdaccel"; break;
    case kDLAOCL:    name = "aocl";    break;
    case kDLVulkan:  name = "vulkan";  break;
    case kDLMetal:   name = "metal";   break;
    case kDLVPI:     name = "vpi";     break;
    case kDLROCM:    name = "rocm";    break;
    case kOpenGL:    name = "opengl";  break;
    case kExtDev:    name = "ext_dev"; break;
    default:
      LOG(FATAL) << "unknown type =" << type;
      name = "Unknown";
      break;
  }
  return name;
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
48
  static DeviceAPI* Get(const DGLContext& ctx) {
Minjie Wang's avatar
Minjie Wang committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
    return Get(ctx.device_type);
  }
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
  DeviceAPIManager() {
    std::fill(api_.begin(), api_.end(), nullptr);
  }
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
89
          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
90
91
92
93
94
95
96
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

97
// Resolve the DeviceAPI for ctx's device type through the shared manager cache.
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  const int device_type = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(device_type, allow_missing);
}

102
void* DeviceAPI::AllocWorkspace(DGLContext ctx,
Minjie Wang's avatar
Minjie Wang committed
103
                                size_t size,
104
                                DGLType type_hint) {
Minjie Wang's avatar
Minjie Wang committed
105
106
107
  return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

108
// Default workspace release: workspace memory is ordinary data space.
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  this->FreeDataSpace(ctx, ptr);
}

112
// Default implementation for devices without stream support: always fatal.
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  LOG(FATAL) << "Device does not support stream api.";
  return nullptr;
}

117
// Default implementation for devices without stream support: always fatal.
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  LOG(FATAL) << "Device does not support stream api.";
}

121
122
123
// Default implementation for devices without stream support: always fatal.
void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
                                 DGLStreamHandle event_src,
                                 DGLStreamHandle event_dst) {
  LOG(FATAL) << "Device does not support stream api.";
}
}  // namespace runtime
127
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
128

129
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
130

131
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
132
133
  std::string ret_str;
  std::string last_error;
134
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
135
136
};

137
// Thread-local singleton holding this thread's DGLRuntimeEntry.
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;
Minjie Wang's avatar
Minjie Wang committed
138

139
140
const char *DGLGetLastError() {
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
141
142
}

143
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
144
#ifndef _LIBCPP_SGX_CONFIG
145
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
146
147
148
149
150
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

151
int DGLModLoadFromFile(const char* file_name,
Minjie Wang's avatar
Minjie Wang committed
152
                       const char* format,
153
                       DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
154
155
156
157
158
159
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

160
161
// Add module `dep` to the imports of module `mod`.
int DGLModImport(DGLModuleHandle mod,
                 DGLModuleHandle dep) {
  API_BEGIN();
  Module* importer = static_cast<Module*>(mod);
  const Module& dependency = *static_cast<Module*>(dep);
  importer->Import(dependency);
  API_END();
}

168
int DGLModGetFunction(DGLModuleHandle mod,
Minjie Wang's avatar
Minjie Wang committed
169
170
                      const char* func_name,
                      int query_imports,
171
                      DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
172
173
174
175
176
177
178
179
180
181
182
  API_BEGIN();
  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
      func_name, query_imports != 0);
  if (pf != nullptr) {
    *func = new PackedFunc(pf);
  } else {
    *func = nullptr;
  }
  API_END();
}

183
// Release a module handle created by this API.
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

189
int DGLBackendGetFuncFromEnv(void* mod_node,
Minjie Wang's avatar
Minjie Wang committed
190
                             const char* func_name,
191
                             DGLFunctionHandle *func) {
Minjie Wang's avatar
Minjie Wang committed
192
  API_BEGIN();
193
  *func = (DGLFunctionHandle)(
Minjie Wang's avatar
Minjie Wang committed
194
195
196
197
      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
  API_END();
}

198
void* DGLBackendAllocWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
199
200
201
202
                               int device_id,
                               uint64_t size,
                               int dtype_code_hint,
                               int dtype_bits_hint) {
203
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
204
205
206
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;

207
  DGLType type_hint;
Minjie Wang's avatar
Minjie Wang committed
208
209
210
211
212
213
214
215
216
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
                                                    static_cast<size_t>(size),
                                                    type_hint);
}

217
int DGLBackendFreeWorkspace(int device_type,
Minjie Wang's avatar
Minjie Wang committed
218
219
                            int device_id,
                            void* ptr) {
220
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
221
222
223
224
225
226
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

227
// Run f(cdata) only the first time this is called with a given *handle.
// A non-null *handle marks "already ran"; later calls return 0 without calling f.
// nbytes is not used by this implementation.
int DGLBackendRunOnce(void** handle,
                      int (*f)(void*),
                      void* cdata,
                      int nbytes) {
  if (*handle != nullptr) {
    return 0;
  }
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

238
// Release a PackedFunc handle allocated by this API.
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

244
245
int DGLFuncCall(DGLFunctionHandle func,
                DGLValue* args,
Minjie Wang's avatar
Minjie Wang committed
246
247
                int* arg_type_codes,
                int num_args,
248
                DGLValue* ret_val,
Minjie Wang's avatar
Minjie Wang committed
249
250
                int* ret_type_code) {
  API_BEGIN();
251
  DGLRetValue rv;
Minjie Wang's avatar
Minjie Wang committed
252
  (*static_cast<const PackedFunc*>(func)).CallPacked(
253
      DGLArgs(args, arg_type_codes, num_args), &rv);
Minjie Wang's avatar
Minjie Wang committed
254
255
  // handle return string.
  if (rv.type_code() == kStr ||
256
     rv.type_code() == kDGLType ||
Minjie Wang's avatar
Minjie Wang committed
257
      rv.type_code() == kBytes) {
258
259
    DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
    if (rv.type_code() != kDGLType) {
Minjie Wang's avatar
Minjie Wang committed
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
      e->ret_str = *rv.ptr<std::string>();
    } else {
      e->ret_str = rv.operator std::string();
    }
    if (rv.type_code() == kBytes) {
      e->ret_bytes.data = e->ret_str.c_str();
      e->ret_bytes.size = e->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(e->ret_bytes);
    } else {
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    }
  } else {
    rv.MoveToCHost(ret_val, ret_type_code);
  }
  API_END();
}

279
280
// Store a single C-side value into the DGLRetValue behind `ret`.
// Exactly one return value is supported (CHECK_EQ on num_ret).
int DGLCFuncSetReturn(DGLRetValueHandle ret,
                      DGLValue* value,
                      int* type_code,
                      int num_ret) {
  API_BEGIN();
  CHECK_EQ(num_ret, 1);
  DGLRetValue& target = *static_cast<DGLRetValue*>(ret);
  target = DGLArgValue(value[0], type_code[0]);
  API_END();
}

290
int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
Minjie Wang's avatar
Minjie Wang committed
291
                           void* resource_handle,
292
293
                           DGLPackedCFuncFinalizer fin,
                           DGLFunctionHandle *out) {
Minjie Wang's avatar
Minjie Wang committed
294
295
296
  API_BEGIN();
  if (fin == nullptr) {
    *out = new PackedFunc(
297
298
        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
299
300
                         args.num_args, rv, resource_handle);
          if (ret != 0) {
301
302
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
303
304
305
306
307
308
309
310
            throw dmlc::Error(err);
          }
        });
  } else {
    // wrap it in a shared_ptr, with fin as deleter.
    // so fin will be called when the lambda went out of scope.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc(
311
312
        [func, rpack](DGLArgs args, DGLRetValue* rv) {
          int ret = func((DGLValue*)args.values, (int*)args.type_codes, // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
313
314
                         args.num_args, rv, rpack.get());
          if (ret != 0) {
315
316
            std::string err = "DGLCall CFunc Error:\n";
            err += DGLGetLastError();
Minjie Wang's avatar
Minjie Wang committed
317
318
319
320
321
322
323
            throw dmlc::Error(err);
          }
      });
  }
  API_END();
}

324
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
325
  API_BEGIN();
326
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
327
328
329
330
331
332
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

333
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
334
  API_BEGIN();
335
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
336
337
338
339
340
341
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

342
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
343
  API_BEGIN();
344
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
345
346
347
348
349
350
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

351
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
352
  API_BEGIN();
353
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
354
355
356
357
358
359
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

360
int DGLStreamStreamSynchronize(int device_type,
Minjie Wang's avatar
Minjie Wang committed
361
                               int device_id,
362
363
                               DGLStreamHandle src,
                               DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
364
  API_BEGIN();
365
  DGLContext ctx;
Minjie Wang's avatar
Minjie Wang committed
366
367
368
369
370
371
  ctx.device_type = static_cast<DLDeviceType>(device_type);
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

372
// Convert a callback argument value in place into its return-value form.
// The type code must be unchanged by the conversion (checked).
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  dgl::runtime::DGLRetValue converted;
  converted = dgl::runtime::DGLArgValue(*value, code);
  int moved_code;
  converted.MoveToCHost(value, &moved_code);
  CHECK_EQ(moved_code, code);
  API_END();
}

382
383
// Load the tensor adapter library at `path` into the global TensorDispatcher.
void DGLLoadTensorAdapter(const char *path) {
  auto* dispatcher = TensorDispatcher::Global();
  dispatcher->Load(path);
}

Minjie Wang's avatar
Minjie Wang committed
386
// set device api
387
388
389
// args: (device_type: int, device_id: int) -> activates that device.
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
.set_body([](DGLArgs args, DGLRetValue *ret) {
    const int device_type = args[0].operator int();
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(device_type);
    ctx.device_id = args[1];
    DeviceAPI* api = DeviceAPIManager::Get(ctx);
    api->SetDevice(ctx);
  });

// get device attribute api
396
397
398
// args: (device_type: int, device_id: int, kind: DeviceAttrKind) -> attribute value.
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
.set_body([](DGLArgs args, DGLRetValue *ret) {
    DGLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[0].operator int());
    ctx.device_id = args[1];
    const DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());

    if (kind != kExist) {
      DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      return;
    }
    // Existence queries tolerate an unregistered device API and report 0.
    DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
    if (api == nullptr) {
      *ret = 0;
      return;
    }
    api->GetAttr(ctx, kind, ret);
  });
414