registry.cc 4.52 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/*!
 *  Copyright (c) 2017 by Contributors
 * \file registry.cc
 * \brief The global registry of packed function.
 */
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <array>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "runtime_base.h"

15
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
namespace runtime {

struct Registry::Manager {
  // Map from global function name to its registry entry.
  // We deliberately store raw pointers and never free the entries:
  // a PackedFunc can hold callbacks into the host language (e.g. Python),
  // and those resources can already be invalid at exit because the order
  // of static destruction is not deterministic.
  std::unordered_map<std::string, Registry*> fmap;
  // vtable for extension types, indexed by type code.
  std::array<ExtTypeVTable, kExtEnd> ext_vtable;
  // Serializes access to fmap; also held while writing ext_vtable.
  std::mutex mutex;

  Manager() {
    // Mark every extension slot as unregistered: ExtTypeVTable::Get treats
    // a null destroy callback as "not registered".
    for (auto& x : ext_vtable) {
      x.destroy = nullptr;
    }
  }

  static Manager* Global() {
    // Heap-allocate the singleton and intentionally never delete it.
    // A function-local static object would be destroyed during program exit.
    // After that, any static destructor that still calls into the registry
    // (Register/Get/Remove/ListNames, ExtTypeVTable::Get) would touch a
    // destroyed map and mutex. Leaking it keeps the registry usable for the
    // whole lifetime of the process, consistent with the leaked entries.
    static Manager* inst = new Manager();
    return inst;
  }
};

/*!
 * \brief Install the implementation of this registered function.
 * \param f The packed function. It is taken by value and moved into place,
 *          so callers passing a temporary pay for no extra copy.
 * \return Reference to this entry, allowing call chaining.
 */
Registry& Registry::set_body(PackedFunc f) {  // NOLINT(*)
  func_ = std::move(f);
  return *this;
}

/*!
 * \brief Look up or create the registry entry for a global function.
 * \param name Global name of the function.
 * \param override When false, registering an already-present name fails
 *        the CHECK. When true, the existing entry is returned for reuse.
 * \return The (new or existing) entry for \p name.
 */
Registry& Registry::Register(const std::string& name, bool override) {  // NOLINT(*)
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto found = manager->fmap.find(name);
  if (found != manager->fmap.end()) {
    CHECK(override)
      << "Global PackedFunc " << name << " is already registered";
    return *found->second;
  }
  // First registration under this name: allocate a fresh entry. Entries
  // are intentionally never freed (see Registry::Manager::fmap).
  Registry* entry = new Registry();
  entry->name_ = name;
  manager->fmap[name] = entry;
  return *entry;
}

bool Registry::Remove(const std::string& name) {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  auto it = m->fmap.find(name);
  if (it == m->fmap.end()) return false;
  m->fmap.erase(it);
  return true;
}

/*!
 * \brief Find a registered global function.
 * \param name Global name of the function.
 * \return Pointer to the function stored in its registry entry, or nullptr
 *         if \p name is not registered.
 */
const PackedFunc* Registry::Get(const std::string& name) {
  Manager* manager = Manager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  const auto found = manager->fmap.find(name);
  return found == manager->fmap.end() ? nullptr : &found->second->func_;
}

std::vector<std::string> Registry::ListNames() {
  Manager* m = Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  std::vector<std::string> keys;
  keys.reserve(m->fmap.size());
  for (const auto &kv : m->fmap) {
    keys.push_back(kv.first);
  }
  return keys;
}

/*!
 * \brief Fetch the vtable registered for an extension type.
 * \param type_code Extension type code; must lie strictly between
 *        kExtBegin and kExtEnd.
 * \return Pointer to the registered vtable. Fails a CHECK if the code is
 *         out of range or no vtable (non-null destroy) was registered.
 */
ExtTypeVTable* ExtTypeVTable::Get(int type_code) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd)
      << "Extension type code " << type_code << " is out of range";
  Registry::Manager* m = Registry::Manager::Global();
  // RegisterInternal writes ext_vtable while holding m->mutex; take the
  // same lock so this read cannot race with a concurrent registration.
  std::lock_guard<std::mutex> lock(m->mutex);
  ExtTypeVTable* vt = &(m->ext_vtable[type_code]);
  CHECK(vt->destroy != nullptr)
      << "Extension type not registered";
  return vt;
}

/*!
 * \brief Store the vtable for an extension type, replacing any previous one.
 * \param type_code Extension type code; must lie strictly between
 *        kExtBegin and kExtEnd.
 * \param vt The vtable to copy into the global table.
 * \return Pointer to the stored vtable slot.
 */
ExtTypeVTable* ExtTypeVTable::RegisterInternal(
    int type_code, const ExtTypeVTable& vt) {
  CHECK(type_code > kExtBegin && type_code < kExtEnd)
      << "Extension type code " << type_code << " is out of range";
  Registry::Manager* m = Registry::Manager::Global();
  std::lock_guard<std::mutex> lock(m->mutex);
  ExtTypeVTable* pvt = &(m->ext_vtable[type_code]);
  *pvt = vt;
  return pvt;
}
}  // namespace runtime
110
}  // namespace dgl
Minjie Wang's avatar
Minjie Wang committed
111
112

/*!
 * \brief Per-call scratch storage that keeps data returned through C API
 *        out-parameters alive after the call returns.
 */
struct DGLFuncThreadLocalEntry {
  /*! \brief result holder for returning strings */
  std::vector<std::string> ret_vec_str;
  /*! \brief result holder for returning string pointers (into ret_vec_str) */
  std::vector<const char *> ret_vec_charp;
};

/*! \brief Thread local store that can be used to hold return values. */
121
typedef dmlc::ThreadLocalStore<DGLFuncThreadLocalEntry> DGLFuncThreadLocalStore;
Minjie Wang's avatar
Minjie Wang committed
122

123
int DGLExtTypeFree(void* handle, int type_code) {
Minjie Wang's avatar
Minjie Wang committed
124
  API_BEGIN();
125
  dgl::runtime::ExtTypeVTable::Get(type_code)->destroy(handle);
Minjie Wang's avatar
Minjie Wang committed
126
127
128
  API_END();
}

129
130
int DGLFuncRegisterGlobal(
    const char* name, DGLFunctionHandle f, int override) {
Minjie Wang's avatar
Minjie Wang committed
131
  API_BEGIN();
132
133
  dgl::runtime::Registry::Register(name, override != 0)
      .set_body(*static_cast<dgl::runtime::PackedFunc*>(f));
Minjie Wang's avatar
Minjie Wang committed
134
135
136
  API_END();
}

137
int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out) {
Minjie Wang's avatar
Minjie Wang committed
138
  API_BEGIN();
139
140
  const dgl::runtime::PackedFunc* fp =
      dgl::runtime::Registry::Get(name);
Minjie Wang's avatar
Minjie Wang committed
141
  if (fp != nullptr) {
142
    *out = new dgl::runtime::PackedFunc(*fp);  // NOLINT(*)
Minjie Wang's avatar
Minjie Wang committed
143
144
145
146
147
148
  } else {
    *out = nullptr;
  }
  API_END();
}

149
int DGLFuncListGlobalNames(int *out_size,
Minjie Wang's avatar
Minjie Wang committed
150
151
                           const char*** out_array) {
  API_BEGIN();
152
153
  DGLFuncThreadLocalEntry *ret = DGLFuncThreadLocalStore::Get();
  ret->ret_vec_str = dgl::runtime::Registry::ListNames();
Minjie Wang's avatar
Minjie Wang committed
154
155
156
157
158
159
160
161
  ret->ret_vec_charp.clear();
  for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
    ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
  }
  *out_array = dmlc::BeginPtr(ret->ret_vec_charp);
  *out_size = static_cast<int>(ret->ret_vec_str.size());
  API_END();
}