c_api_common.cc 1.68 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2018 by Contributors
 * \file c_api_common.cc
 * \brief DGL C API common implementations
 */
Lingfan Yu's avatar
Lingfan Yu committed
6
7
#include "c_api_common.h"

8
9
10
11
12
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
Lingfan Yu's avatar
Lingfan Yu committed
13
14
15

namespace dgl {

16
/*!
 * \brief Wrap the DLTensor carried by a packed-function argument in a newly
 *        heap-allocated DLManagedTensor.
 *
 * Only the DLTensor header struct is copied (by value); the buffers it points
 * to are not duplicated, so the returned object merely borrows them.
 * NOTE(review): callers presumably must keep the source tensor alive while the
 * returned wrapper is in use — confirm at call sites.
 *
 * \param arg Packed-function argument convertible to `const DLTensor*`.
 * \return A new DLManagedTensor with a null `manager_ctx` whose `deleter`
 *         frees only the wrapper object itself.
 */
DLManagedTensor* CreateTmpDLManagedTensor(const DGLArgValue& arg) {
  const DLTensor* source = arg;
  auto* wrapped = new DLManagedTensor();
  wrapped->dl_tensor = *source;
  wrapped->manager_ctx = nullptr;
  // The deleter releases the wrapper only; the tensor data is not owned.
  wrapped->deleter = [](DLManagedTensor* self) { delete self; };
  return wrapped;
}

/*!
 * \brief Expose a vector of NDArrays to the frontend as an indexable PackedFunc.
 *
 * The returned function takes a single argument, an index `which`, and returns
 * `vec[which]`. The vector is captured by value, so the function holds its own
 * copy and can be invoked any number of times with any valid index.
 *
 * \param vec NDArrays to expose.
 * \return PackedFunc `(which) -> NDArray`; an out-of-range index aborts via
 *         LOG(FATAL).
 */
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
  auto body = [vec](DGLArgs args, DGLRetValue* rv) {
    const uint64_t which = args[0];
    if (which >= vec.size()) {
      LOG(FATAL) << "invalid choice";
    } else {
      // Plain copy on purpose: `vec` is const inside this non-mutable lambda,
      // so std::move would silently copy anyway, and a real move would empty
      // the slot and break later calls with the same index.
      *rv = vec[which];
    }
  };
  // Move the closure into the PackedFunc so the captured vector is not copied
  // a second time.
  return PackedFunc(std::move(body));
}

37
38
39
40
41
42
43
44
45
46
47
48
49
// Frontend hook: given an opaque CAPIVectorWrapper handle in args[0], return
// the number of entries in its `pointers` vector as an int64.
DGL_REGISTER_GLOBAL("_GetVectorWrapperSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    void* handle = args[0];
    const auto* vec_wrapper = static_cast<const CAPIVectorWrapper*>(handle);
    const int64_t count = static_cast<int64_t>(vec_wrapper->pointers.size());
    *rv = count;
  });

// Frontend hook: given an opaque CAPIVectorWrapper handle in args[0], return a
// raw pointer to the first element of its `pointers` storage.
// NOTE(review): the pointer is only valid while the wrapper is alive and its
// vector is not resized.
DGL_REGISTER_GLOBAL("_GetVectorWrapperData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    void* handle = args[0];
    auto* vec_wrapper = static_cast<CAPIVectorWrapper*>(handle);
    void* const entries = static_cast<void*>(vec_wrapper->pointers.data());
    *rv = entries;
  });
Lingfan Yu's avatar
Lingfan Yu committed
50

51
52
53
54
55
56
57
58
// Frontend hook: destroy the CAPIVectorWrapper whose opaque handle is passed in
// args[0]. Returns nothing.
DGL_REGISTER_GLOBAL("_FreeVectorWrapper")
.set_body([] (DGLArgs args, DGLRetValue* /* rv */) {
    void* handle = args[0];
    delete static_cast<CAPIVectorWrapper*>(handle);
  });

}  // namespace dgl