c_object_api.cc 4.73 KB
Newer Older
1
2
3
4
5
/*!
 *  Copyright (c) 2016 by Contributors
 * Implementation of the C object API (reference: tvm/src/api/c_api.cc)
 * \file c_object_api.cc
 */
6
7
8
9
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>

#include <exception>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "runtime_base.h"

/*! \brief Entry that holds the values returned through the C API. */
struct DGLAPIThreadLocalEntry {
  /*! \brief Result holder that owns the strings of a returned string list. */
  std::vector<std::string> ret_vec_str;
  /*! \brief Result holder for the C-string pointers of a returned list. */
  std::vector<const char*> ret_vec_charp;
  /*! \brief Result holder that owns a single returned string. */
  std::string ret_str;
};

using namespace dgl::runtime;

/*! \brief Thread-local store of the entry that holds C API return values. */
using DGLAPIThreadLocalStore = dmlc::ThreadLocalStore<DGLAPIThreadLocalEntry>;

/*! \brief Object handles of this API are cast to pointers to this type. */
using DGLAPIObject = std::shared_ptr<Object>;

struct APIAttrGetter : public AttrVisitor {
  std::string skey;
  DGLRetValue* ret;
  bool found_object_ref{false};

  void Visit(const char* key, double* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
49
50
    CHECK_LE(
        value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
        << "cannot return too big constant";
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, int* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, bool* value) final {
    if (skey == key) *ret = static_cast<int64_t>(value[0]);
  }
  void Visit(const char* key, std::string* value) final {
    if (skey == key) *ret = value[0];
  }
  void Visit(const char* key, ObjectRef* value) final {
    if (skey == key) {
      *ret = value[0];
      found_object_ref = true;
    }
  }
Minjie Wang's avatar
Minjie Wang committed
69
70
71
  void Visit(const char* key, NDArray* value) final {
    if (skey == key) *ret = value[0];
  }
72
73
74
75
76
};

struct APIAttrDir : public AttrVisitor {
  std::vector<std::string>* names;

77
78
79
80
81
  void Visit(const char* key, double* value) final { names->push_back(key); }
  void Visit(const char* key, int64_t* value) final { names->push_back(key); }
  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
  void Visit(const char* key, bool* value) final { names->push_back(key); }
  void Visit(const char* key, int* value) final { names->push_back(key); }
82
83
84
  void Visit(const char* key, std::string* value) final {
    names->push_back(key);
  }
85
86
  void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }
  void Visit(const char* key, NDArray* value) final { names->push_back(key); }
87
88
89
90
91
92
93
94
};

/*!
 * \brief Delete the DGLAPIObject (a std::shared_ptr<Object>) behind a handle.
 * \param handle Pointer to a DGLAPIObject; it is passed to delete.
 * \return The status produced by API_END().
 *  NOTE(review): presumably 0 on success -- confirm in runtime_base.h.
 */
int DGLObjectFree(ObjectHandle handle) {
  API_BEGIN();
  DGLAPIObject* owner = static_cast<DGLAPIObject*>(handle);
  delete owner;
  API_END();
}

95
int DGLObjectTypeKey2Index(const char* type_key, int* out_index) {
96
97
98
99
100
  API_BEGIN();
  *out_index = static_cast<int>(Object::TypeKey2Index(type_key));
  API_END();
}

101
/*!
 * \brief Query the type index of the object behind a handle.
 * \param handle Pointer to a DGLAPIObject (a std::shared_ptr<Object>).
 * \param out_index Receives the object's type_index() converted to int.
 * \return The status produced by API_END().
 *  NOTE(review): presumably 0 on success -- confirm in runtime_base.h.
 */
int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {
  API_BEGIN();
  // Bind by reference: reading the index does not need another owner.
  const DGLAPIObject& object = *static_cast<DGLAPIObject*>(handle);
  *out_index = static_cast<int>(object->type_index());
  API_END();
}

108
109
110
/*!
 * \brief Read one attribute of the object behind a handle.
 *
 * The key "type_key" is answered directly from the object's type_key().
 * Every other key is looked up by visiting the object's attributes with an
 * APIAttrGetter.
 *
 * \param handle Pointer to a DGLAPIObject (a std::shared_ptr<Object>).
 * \param key Name of the attribute to read; must not be null.
 * \param ret_val Receives the attribute value.
 * \param ret_type_code Receives the type code of \p ret_val.
 * \param ret_success Set to 1 for "type_key"; otherwise nonzero iff an
 *  ObjectRef attribute matched or the fetched value's type code is not kNull.
 * \return The status produced by API_END().
 *  NOTE(review): presumably 0 on success -- confirm in runtime_base.h.
 *
 * \note Values of type kStr or kDGLDataType are returned as a C string that
 *  points into the thread-local entry's ret_str, which is overwritten by the
 *  next such call made through the same entry.
 */
int DGLObjectGetAttr(
    ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,
    int* ret_success) {
  API_BEGIN();
  // Constructing a std::string from a null pointer is undefined behaviour,
  // so reject it explicitly before the key is copied.
  CHECK(key != nullptr) << "attribute key must not be null";
  DGLRetValue rv;
  APIAttrGetter getter;
  getter.skey = key;
  getter.ret = &rv;
  DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
  if (getter.skey == "type_key") {
    ret_val->v_str = (*tobject)->type_key();
    *ret_type_code = kStr;
    *ret_success = 1;
  } else {
    (*tobject)->VisitAttrs(&getter);
    *ret_success = getter.found_object_ref || rv.type_code() != kNull;
    if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {
      // Both cases are handed back as a string; keep the string in the
      // thread-local entry so the returned pointer stays valid after rv is
      // destroyed.
      DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();
      e->ret_str = rv.operator std::string();
      *ret_type_code = kStr;
      ret_val->v_str = e->ret_str.c_str();
    } else {
      rv.MoveToCHost(ret_val, ret_type_code);
    }
  }
  API_END();
}

136
137
138
/*!
 * \brief List the attribute names of the object behind a handle.
 *
 * The names are collected by visiting the object's attributes with an
 * APIAttrDir. Both the strings and the pointer array handed back live in the
 * thread-local entry and are cleared and refilled by every call.
 *
 * \param handle Pointer to a DGLAPIObject (a std::shared_ptr<Object>).
 * \param out_size Receives the number of attribute names.
 * \param out_array Receives dmlc::BeginPtr() of the array of C strings.
 * \return The status produced by API_END().
 *  NOTE(review): presumably 0 on success -- confirm in runtime_base.h.
 */
int DGLObjectListAttrNames(
    ObjectHandle handle, int* out_size, const char*** out_array) {
  DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();
  API_BEGIN();
  ret->ret_vec_str.clear();
  DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
  APIAttrDir dir;
  dir.names = &(ret->ret_vec_str);
  (*tobject)->VisitAttrs(&dir);
  // Take the c_str() pointers only after the name list is complete, so that
  // later growth of ret_vec_str cannot invalidate them.
  ret->ret_vec_charp.clear();
  ret->ret_vec_charp.reserve(ret->ret_vec_str.size());
  for (const std::string& name : ret->ret_vec_str) {
    ret->ret_vec_charp.push_back(name.c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
  API_END();
}