c_runtime_api.cc 11.8 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
2
/**
 *  Copyright (c) 2016-2022 by Contributors
 * @file c_runtime_api.cc
 * @brief Runtime API implementation
 */
#include <dgl/runtime/c_backend_api.h>
8
9
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
Minjie Wang's avatar
Minjie Wang committed
10
#include <dgl/runtime/module.h>
11
#include <dgl/runtime/packed_func.h>
Minjie Wang's avatar
Minjie Wang committed
12
#include <dgl/runtime/registry.h>
13
#include <dgl/runtime/tensordispatch.h>
14
15
#include <dmlc/thread_local.h>

Minjie Wang's avatar
Minjie Wang committed
16
#include <algorithm>
17
#include <array>
Minjie Wang's avatar
Minjie Wang committed
18
#include <cstdlib>
19
20
#include <string>

Minjie Wang's avatar
Minjie Wang committed
21
22
#include "runtime_base.h"

23
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
24
25
namespace runtime {

26
/**
27
28
 * @brief The name of Device API factory.
 * @param type The device type.
Minjie Wang's avatar
Minjie Wang committed
29
30
31
 */
inline std::string DeviceName(int type) {
  switch (type) {
32
33
34
35
    case kDGLCPU:
      return "cpu";
    case kDGLCUDA:
      return "cuda";
sangwzh's avatar
sangwzh committed
36
37
    case kDGLROCM:
      return "cuda";
38
    // add more device here once supported
39
40
41
    default:
      LOG(FATAL) << "unknown type =" << type;
      return "Unknown";
Minjie Wang's avatar
Minjie Wang committed
42
43
44
45
46
47
48
  }
}

class DeviceAPIManager {
 public:
  static const int kMaxDeviceAPI = 32;
  // Get API
49
  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
Minjie Wang's avatar
Minjie Wang committed
50
51
52
53
54
55
56
57
58
  static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
    return Global()->GetAPI(dev_type, allow_missing);
  }

 private:
  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
  DeviceAPI* rpc_api_{nullptr};
  std::mutex mutex_;
  // constructor
59
  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
Minjie Wang's avatar
Minjie Wang committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  // Global static variable.
  static DeviceAPIManager* Global() {
    static DeviceAPIManager inst;
    return &inst;
  }
  // Get or initialize API.
  DeviceAPI* GetAPI(int type, bool allow_missing) {
    if (type < kRPCSessMask) {
      if (api_[type] != nullptr) return api_[type];
      std::lock_guard<std::mutex> lock(mutex_);
      if (api_[type] != nullptr) return api_[type];
      api_[type] = GetAPI(DeviceName(type), allow_missing);
      return api_[type];
    } else {
      if (rpc_api_ != nullptr) return rpc_api_;
      std::lock_guard<std::mutex> lock(mutex_);
      if (rpc_api_ != nullptr) return rpc_api_;
      rpc_api_ = GetAPI("rpc", allow_missing);
      return rpc_api_;
    }
  }
  DeviceAPI* GetAPI(const std::string name, bool allow_missing) {
    std::string factory = "device_api." + name;
    auto* f = Registry::Get(factory);
    if (f == nullptr) {
      CHECK(allow_missing)
86
87
          << "Device API " << name
          << " is not enabled. Please install the cuda version of dgl.";
Minjie Wang's avatar
Minjie Wang committed
88
89
90
91
92
93
94
      return nullptr;
    }
    void* ptr = (*f)();
    return static_cast<DeviceAPI*>(ptr);
  }
};

95
DeviceAPI* DeviceAPI::Get(DGLContext ctx, bool allow_missing) {
  // Only the device type of the context selects the API.
  const int type_code = static_cast<int>(ctx.device_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

100
DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
  const int type_code = static_cast<int>(dev_type);
  return DeviceAPIManager::Get(type_code, allow_missing);
}

104
105
void* DeviceAPI::AllocWorkspace(
    DGLContext ctx, size_t size, DGLDataType type_hint) {
  // Default workspace: ordinary data space with kTempAllocaAlignment.
  return this->AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
}

109
void DeviceAPI::FreeWorkspace(DGLContext ctx, void* ptr) {
  // Default: workspace memory is released like any other data space.
  this->FreeDataSpace(ctx, ptr);
}

113
DGLStreamHandle DeviceAPI::CreateStream(DGLContext ctx) {
  // Base-class default: stream creation is unsupported here.
  LOG(FATAL) << "Device does not support stream api.";
  return nullptr;
}

118
void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
  // Base-class default: streams are unsupported here.
  LOG(FATAL) << "Device does not support stream api.";
}

122
123
void DeviceAPI::SyncStreamFromTo(
    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
  // Base-class default: cross-stream synchronization is unsupported here.
  LOG(FATAL) << "Device does not support stream api.";
}
126

127
bool DeviceAPI::PinData(void* ptr, size_t nbytes) {
sangwzh's avatar
sangwzh committed
128
  LOG(FATAL) << "Device does not support hipHostRegister api.";
129
  return false;
130
131
}

132
133
void* DeviceAPI::AllocPinnedDataSpace(
    size_t nbytes, void** ctx, void** deleter) {
sangwzh's avatar
sangwzh committed
134
  LOG(FATAL) << "Device does not support hipHostMalloc api.";
135
136
137
138
139
140
141
  return nullptr;
}

void DeviceAPI::FreePinnedDataSpace(void** deleter) {
  // Base-class default: freeing pinned host memory is unsupported here.
  // The message names hipHostFree, the HIP counterpart of cudaFreeHost.
  // This keeps it consistent with the hipHostMalloc / hipHostRegister /
  // hipHostUnregister messages of the sibling defaults ("cudaHostFree" is not
  // a real API).
  LOG(FATAL) << "Device does not support hipHostFree api.";
}

142
void DeviceAPI::UnpinData(void* ptr) {
sangwzh's avatar
sangwzh committed
143
  LOG(FATAL) << "Device does not support hipHostUnregister api.";
144
}
Minjie Wang's avatar
Minjie Wang committed
145
}  // namespace runtime
146
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
147

148
using namespace dgl::runtime;
Minjie Wang's avatar
Minjie Wang committed
149

150
struct DGLRuntimeEntry {
Minjie Wang's avatar
Minjie Wang committed
151
152
  std::string ret_str;
  std::string last_error;
153
  DGLByteArray ret_bytes;
Minjie Wang's avatar
Minjie Wang committed
154
155
};

156
// Thread-local accessor for the per-thread DGLRuntimeEntry.
using DGLAPIRuntimeStore = dmlc::ThreadLocalStore<DGLRuntimeEntry>;
Minjie Wang's avatar
Minjie Wang committed
157

158
const char* DGLGetLastError() {
159
  return DGLAPIRuntimeStore::Get()->last_error.c_str();
Minjie Wang's avatar
Minjie Wang committed
160
161
}

162
void DGLAPISetLastError(const char* msg) {
Minjie Wang's avatar
Minjie Wang committed
163
#ifndef _LIBCPP_SGX_CONFIG
164
  DGLAPIRuntimeStore::Get()->last_error = msg;
Minjie Wang's avatar
Minjie Wang committed
165
166
167
168
169
#else
  sgx::OCallPackedFunc("__sgx_set_last_error__", msg);
#endif
}

170
171
int DGLModLoadFromFile(
    const char* file_name, const char* format, DGLModuleHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
172
173
174
175
176
177
  API_BEGIN();
  Module m = Module::LoadFromFile(file_name, format);
  *out = new Module(m);
  API_END();
}

178
int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
  API_BEGIN();
  Module* importer = static_cast<Module*>(mod);
  Module& dependency = *static_cast<Module*>(dep);
  importer->Import(dependency);
  API_END();
}

184
185
186
int DGLModGetFunction(
    DGLModuleHandle mod, const char* func_name, int query_imports,
    DGLFunctionHandle* func) {
  API_BEGIN();
  Module* module = static_cast<Module*>(mod);
  const bool search_imports = query_imports != 0;
  PackedFunc found = module->GetFunction(func_name, search_imports);
  // A null handle signals "not found"; otherwise the caller owns a heap copy
  // that must be released with DGLFuncFree.
  *func = (found != nullptr) ? new PackedFunc(found) : nullptr;
  API_END();
}

198
int DGLModFree(DGLModuleHandle mod) {
  API_BEGIN();
  Module* owned = static_cast<Module*>(mod);
  delete owned;
  API_END();
}

204
205
int DGLBackendGetFuncFromEnv(
    void* mod_node, const char* func_name, DGLFunctionHandle* func) {
Minjie Wang's avatar
Minjie Wang committed
206
  API_BEGIN();
207
208
209
  *func =
      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
          func_name));
Minjie Wang's avatar
Minjie Wang committed
210
211
212
  API_END();
}

213
214
215
void* DGLBackendAllocWorkspace(
    int device_type, int device_id, uint64_t size, int dtype_code_hint,
    int dtype_bits_hint) {
216
  DGLContext ctx;
217
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
218
219
  ctx.device_id = device_id;

220
  DGLDataType type_hint;
Minjie Wang's avatar
Minjie Wang committed
221
222
223
224
  type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
  type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
  type_hint.lanes = 1;

225
226
  return DeviceAPIManager::Get(ctx)->AllocWorkspace(
      ctx, static_cast<size_t>(size), type_hint);
Minjie Wang's avatar
Minjie Wang committed
227
228
}

229
int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
230
  DGLContext ctx;
231
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
232
233
234
235
236
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeWorkspace(ctx, ptr);
  return 0;
}

237
/**
 * @brief Invoke f(cdata) the first time *handle is null; afterwards return 0.
 *
 * Any non-null value in *handle marks the work as done. `nbytes` is unused.
 * NOTE(review): there is no synchronization here, so concurrent first callers
 * could each run f.
 */
int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
  if (*handle != nullptr) return 0;
  *handle = reinterpret_cast<void*>(1);
  return f(cdata);
}

245
int DGLFuncFree(DGLFunctionHandle func) {
  API_BEGIN();
  PackedFunc* owned = static_cast<PackedFunc*>(func);
  delete owned;
  API_END();
}

251
252
253
int DGLFuncCall(
    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
    DGLValue* ret_val, int* ret_type_code) {
  API_BEGIN();
  const PackedFunc& callee = *static_cast<const PackedFunc*>(func);
  DGLRetValue rv;
  callee.CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);

  const int rv_code = rv.type_code();
  const bool needs_string_storage =
      rv_code == kStr || rv_code == kBytes || rv_code == kDGLDataType;
  if (!needs_string_storage) {
    rv.MoveToCHost(ret_val, ret_type_code);
  } else {
    // String-like results are copied into this thread's DGLRuntimeEntry so
    // the pointer returned to C refers to that storage, not to rv. It stays
    // valid until a later call on this thread overwrites it.
    DGLRuntimeEntry* entry = DGLAPIRuntimeStore::Get();
    if (rv_code == kDGLDataType) {
      entry->ret_str = rv.operator std::string();
    } else {
      entry->ret_str = *rv.ptr<std::string>();
    }
    if (rv_code != kBytes) {
      // kStr and kDGLDataType are both reported as strings.
      *ret_type_code = kStr;
      ret_val->v_str = entry->ret_str.c_str();
    } else {
      entry->ret_bytes.data = entry->ret_str.c_str();
      entry->ret_bytes.size = entry->ret_str.length();
      *ret_type_code = kBytes;
      ret_val->v_handle = &(entry->ret_bytes);
    }
  }
  API_END();
}

282
283
int DGLCFuncSetReturn(
    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
  API_BEGIN();
  // Exactly one return value is supported.
  CHECK_EQ(num_ret, 1);
  DGLRetValue* target = static_cast<DGLRetValue*>(ret);
  *target = DGLArgValue(value[0], type_code[0]);
  API_END();
}

291
292
293
int DGLFuncCreateFromCFunc(
    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
    DGLFunctionHandle* out) {
  API_BEGIN();
  if (fin != nullptr) {
    // Tie the resource to the closure: the shared_ptr invokes `fin` once the
    // last copy of the wrapping PackedFunc (and thus of `rpack`) is destroyed.
    std::shared_ptr<void> rpack(resource_handle, fin);
    *out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
      const int status = func(
          (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
          args.num_args, rv, rpack.get());
      if (status == 0) return;
      // Non-zero status: surface the C side's last error as an exception.
      std::string message = "DGLCall CFunc Error:\n";
      message += DGLGetLastError();
      throw dmlc::Error(message);
    });
  } else {
    // No finalizer: resource_handle is captured as a raw pointer and is never
    // released by this wrapper.
    *out =
        new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
          const int status = func(
              (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
              args.num_args, rv, resource_handle);
          if (status == 0) return;
          std::string message = "DGLCall CFunc Error:\n";
          message += DGLGetLastError();
          throw dmlc::Error(message);
        });
  }
  API_END();
}

325
int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
326
  API_BEGIN();
327
  DGLContext ctx;
328
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
329
330
331
332
333
  ctx.device_id = device_id;
  *out = DeviceAPIManager::Get(ctx)->CreateStream(ctx);
  API_END();
}

334
int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
335
  API_BEGIN();
336
  DGLContext ctx;
337
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
338
339
340
341
342
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->FreeStream(ctx, stream);
  API_END();
}

343
int DGLSetStream(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
344
  API_BEGIN();
345
  DGLContext ctx;
346
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
347
348
349
350
351
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SetStream(ctx, stream);
  API_END();
}

352
353
354
int DGLGetStream(int device_type, int device_id, DGLStreamHandle* stream) {
  API_BEGIN();
  DGLContext ctx;
355
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
356
357
358
359
360
  ctx.device_id = device_id;
  *stream = DeviceAPIManager::Get(ctx)->GetStream();
  API_END();
}

361
int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
Minjie Wang's avatar
Minjie Wang committed
362
  API_BEGIN();
363
  DGLContext ctx;
364
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
365
366
367
368
369
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
}

370
371
int DGLStreamStreamSynchronize(
    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
Minjie Wang's avatar
Minjie Wang committed
372
  API_BEGIN();
373
  DGLContext ctx;
374
  ctx.device_type = static_cast<DGLDeviceType>(device_type);
Minjie Wang's avatar
Minjie Wang committed
375
376
377
378
379
  ctx.device_id = device_id;
  DeviceAPIManager::Get(ctx)->SyncStreamFromTo(ctx, src, dst);
  API_END();
}

380
int DGLCbArgToReturn(DGLValue* value, int code) {
  API_BEGIN();
  // Round-trip the argument through a return value and write it back in place.
  dgl::runtime::DGLRetValue staged;
  staged = dgl::runtime::DGLArgValue(*value, code);
  int moved_code;
  staged.MoveToCHost(value, &moved_code);
  // The type code must survive the round trip unchanged.
  CHECK_EQ(moved_code, code);
  API_END();
}

390
int DGLLoadTensorAdapter(const char* path) {
391
  return TensorDispatcher::Global()->Load(path) ? 0 : -1;
392
393
}

Minjie Wang's avatar
Minjie Wang committed
394
// set device api
395
DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      // Arguments: (device_type, device_id).
      const int device_type = args[0];
      const int device_id = args[1];
      DGLContext ctx{static_cast<DGLDeviceType>(device_type), device_id};
      DeviceAPIManager::Get(ctx)->SetDevice(ctx);
    });
Minjie Wang's avatar
Minjie Wang committed
402
403

// get device attribute api
404
DGL_REGISTER_GLOBAL("_GetDeviceAttr")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      // Arguments: (device_type, device_id, attribute kind).
      const int device_type = args[0];
      const int device_id = args[1];
      const int kind_code = args[2];
      DGLContext ctx{static_cast<DGLDeviceType>(device_type), device_id};
      const DeviceAttrKind kind = static_cast<DeviceAttrKind>(kind_code);

      if (kind != kExist) {
        // Any other attribute requires the Device API to be available.
        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
        return;
      }
      // kExist: probe with allow_missing so a missing API yields 0.
      DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
      if (api == nullptr) {
        *ret = 0;
      } else {
        api->GetAttr(ctx, kind, ret);
      }
    });