/*!
 *  Copyright (c) 2017 by Contributors
 * \file registry.cc
 * \brief The global registry of packed functions.
 */
#include <dgl/runtime/registry.h>

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>

#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "runtime_base.h"

namespace dgl {
namespace runtime {

struct Registry::Manager {
  /*!
   * \brief Name -> entry table for global PackedFuncs.
   *
   * Entries are held by raw pointer on purpose and are never freed by this
   * table. A PackedFunc can wrap a callback into the host language (Python),
   * and destroying it during static destruction could touch resources that
   * are already invalid, because the order of destruction is not
   * deterministic. The memory is only reclaimed at program exit.
   */
  std::unordered_map<std::string, Registry*> fmap;
  /*! \brief vtables for extension types, indexed by type code */
  std::array<ExtTypeVTable, kExtEnd> ext_vtable;
  /*! \brief guards fmap and writes to ext_vtable */
  std::mutex mutex;

  Manager() {
    // A null destroy hook marks an extension slot as "not registered";
    // ExtTypeVTable::Get checks for exactly that.
    for (size_t code = 0; code < ext_vtable.size(); ++code) {
      ext_vtable[code].destroy = nullptr;
    }
  }

  /*! \return the process-wide singleton, constructed on first use */
  static Manager* Global() {
    static Manager instance;
    return &instance;
  }
};

/*!
 * \brief Install \p f as the body of this registry entry.
 * \param f the function to store. It is taken by value and moved into place,
 *        so a caller passing a temporary does not pay for a second copy.
 * \return *this, so that calls can be chained after Register().
 * \note NOTE(review): func_ is written without holding Manager::mutex, while
 *       DGLFuncGetGlobal copies the PackedFunc returned by Get(); concurrent
 *       re-registration and lookup of the same name may race.
 */
Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
  func_ = std::move(f);
  return *this;
}

/*!
 * \brief Get or create the registry entry for the global function \p name.
 *
 * If \p name is unused, a new entry is created and returned. If \p name is
 * already registered, the existing entry is returned when \p override is
 * true; otherwise CHECK fails with an "already registered" error.
 *
 * \param name the global function name.
 * \param override whether an existing registration may be reused.
 * \return the registry entry; callers set its body with set_body().
 */
Registry& Registry::Register(
    const std::string& name, bool override) {  // NOLINT(*)
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto found = manager->fmap.find(name);
  if (found != manager->fmap.end()) {
    CHECK(override) << "Global PackedFunc " << name << " is already registered";
    return *found->second;
  }
  // Deliberately a raw allocation that is never freed; see Manager::fmap.
  Registry* entry = new Registry();
  entry->name_ = name;
  manager->fmap.emplace(name, entry);
  return *entry;
}

bool Registry::Remove(const std::string& name) {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  auto it = m->fmap.find(name);
  if (it == m->fmap.end()) return false;
  m->fmap.erase(it);
  return true;
}

/*!
 * \brief Look up a registered global function.
 * \param name the global function name.
 * \return a pointer to the stored PackedFunc, or nullptr if \p name is not
 *         registered. The pointer refers into a never-freed Registry entry.
 */
const PackedFunc* Registry::Get(const std::string& name) {
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto found = manager->fmap.find(name);
  return found == manager->fmap.end() ? nullptr : &found->second->func_;
}

std::vector<std::string> Registry::ListNames() {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  std::vector<std::string> keys;
  keys.reserve(m->fmap.size());
88
  for (const auto& kv : m->fmap) {
Minjie Wang's avatar
Minjie Wang committed
89
90
91
92
93
94
95
96
97
    keys.push_back(kv.first);
  }
  return keys;
}

/*!
 * \brief Fetch the vtable registered for extension type \p type_code.
 * \param type_code must lie strictly between kExtBegin and kExtEnd.
 * \return pointer into the global extension vtable table.
 *
 * CHECK fails if \p type_code is out of range or if no vtable has been
 * registered for it (its destroy hook is still null).
 */
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  Registry::Manager* m = Registry::Manager::Global();
  // RegisterInternal() writes ext_vtable while holding this mutex; take it
  // here too so the registration check below does not race with a writer.
  std::lock_guard<std::mutex> lock(m->mutex);
  ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
  CHECK(vt->destroy != nullptr) << "Extension type not registered";
  return vt;
}

/*!
 * \brief Store a copy of \p vt as the vtable for extension type \p type_code.
 * \param type_code must lie strictly between kExtBegin and kExtEnd.
 * \param vt the vtable to install; it replaces any previous registration.
 * \return pointer to the stored copy inside the global table.
 */
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
    int type_code, const ExtTypeVTable& vt) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd);
  Registry::Manager* manager = Registry::Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  ExtTypeVTable& slot = manager->ext_vtable[type_code];
  slot = vt;
  return &slot;
}
}  // namespace runtime
}  // namespace dgl

/*!
 * \brief Scratch buffers that the C API uses to hand results back to callers.
 *
 * Pointers handed out from these buffers stay valid only until the next call
 * that refills them.
 */
struct DGLFuncThreadLocalEntry {
  /*! \brief owns the strings returned by list-style API calls */
  std::vector<std::string> ret_vec_str;
  /*! \brief C-string views into ret_vec_str, returned as const char** */
  std::vector<const char*> ret_vec_charp;
};

/*! \brief Thread local store that can be used to hold return values. */
123
typedef dmlc::ThreadLocalStore<DGLFuncThreadLocalEntry> DGLFuncThreadLocalStore;
int DGLExtTypeFree(void* handle, int type_code) {
Minjie Wang's avatar
Minjie Wang committed
126
  API_BEGIN();
127
  dgl::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
Minjie Wang's avatar
Minjie Wang committed
128
129
130
  API_END();
}

int DGLFuncRegisterGlobal(const char* name, DGLFunctionHandle f, int override) {
Minjie Wang's avatar
Minjie Wang committed
132
  API_BEGIN();
133
134
  dgl::runtime::Registry::Register(name, override != 0)
      .set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
Minjie Wang's avatar
Minjie Wang committed
135
136
137
  API_END();
}

int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
139
  API_BEGIN();
140
  const dgl::runtime::PackedFunc* fp = dgl::runtime::Registry::Get(name);
Minjie Wang's avatar
Minjie Wang committed
141
  if (fp != nullptr) {
142
    *out = new dgl::runtime::PackedFunc(*fp);  // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
143
144
145
146
147
148
  } else {
    *out = nullptr;
  }
  API_END();
}

int DGLFuncListGlobalNames(int* out_size, const char*** out_array) {
Minjie Wang's avatar
Minjie Wang committed
150
  API_BEGIN();
151
  DGLFuncThreadLocalEntry* ret = DGLFuncThreadLocalStore::Get();
152
  ret->ret_vec_str = dgl::runtime::Registry::ListNames();
Minjie Wang's avatar
Minjie Wang committed
153
154
155
156
157
158
159
160
  ret->ret_vec_charp.clear();
  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
  API_END();
}