Unverified Commit 90e78c58 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[FFI] FFI container support, custom structure extension via Object (#693)

* WIP: import tvm runtime node system

* WIP: object system

* containers

* tested basic container composition

* tested custom object

* fix setattr bug

* tested object container return

* fix lint

* some comments about get/set state

* fix lint

* fix lint

* update cython

* fix cython

* ffi doc

* fix doc
parent 684a61ad
"""Container data structures used in DGL runtime.
reference: tvm/python/tvm/collections.py
"""
from __future__ import absolute_import as _abs
from ._ffi.object import ObjectBase, register_object
from . import _api_internal
@register_object
class List(ObjectBase):
"""List container of DGL.
You do not need to create List explicitly.
Normally python list and tuple will be converted automatically
to List during dgl function call.
You may get List in return values of DGL function call.
"""
def __getitem__(self, i):
if isinstance(i, slice):
start = i.start if i.start is not None else 0
stop = i.stop if i.stop is not None else len(self)
step = i.step if i.step is not None else 1
if start < 0:
start += len(self)
if stop < 0:
stop += len(self)
return [self[idx] for idx in range(start, stop, step)]
if i < -len(self) or i >= len(self):
raise IndexError("List index out of range. List size: {}, got index {}"
.format(len(self), i))
if i < 0:
i += len(self)
return _api_internal._ListGetItem(self, i)
def __len__(self):
return _api_internal._ListSize(self)
@register_object
class Map(ObjectBase):
"""Map container of DGL.
You do not need to create Map explicitly.
Normally python dict will be converted automaticall to Map during dgl function call.
You can use convert to create a dict[ObjectBase-> ObjectBase] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
def __contains__(self, k):
return _api_internal._MapCount(self, k) != 0
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]
def __len__(self):
return _api_internal._MapSize(self)
@register_object
class StrMap(Map):
"""A special map container that has str as key.
You can use convert to create a dict[str->ObjectBase] into a Map.
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]
/*!
* Copyright (c) 2019 by Contributors
* \file api/api_container.cc
* \brief Runtime container APIs. (reference: tvm/src/api/api_lang.cc)
*/
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h>
namespace dgl {
namespace runtime {
DGL_REGISTER_GLOBAL("_List")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto ret_obj = std::make_shared<runtime::ListObject>();
for (int i = 0; i < args.size(); ++i) {
ret_obj->data.push_back(args[i].obj_sptr());
}
*rv = ret_obj;
});
DGL_REGISTER_GLOBAL("_ListGetItem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
CHECK(sptr->is_type<ListObject>());
auto* o = static_cast<const ListObject*>(sptr.get());
int64_t i = args[1];
CHECK_LT(i, o->data.size()) << "list out of bound";
*rv = o->data[i];
});
DGL_REGISTER_GLOBAL("_ListSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
CHECK(sptr->is_type<ListObject>());
auto* o = static_cast<const ListObject*>(sptr.get());
*rv = static_cast<int64_t>(o->data.size());
});
DGL_REGISTER_GLOBAL("_Map")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kStr) {
// StrMap
StrMapObject::ContainerType data;
for (int i = 0; i < args.size(); i += 2) {
CHECK(args[i].type_code() == kStr)
<< "The key of the map must be string";
CHECK(args[i + 1].type_code() == kObjectHandle)
<< "The value of the map must be an object type";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].obj_sptr()));
}
auto obj = std::make_shared<StrMapObject>();
obj->data = std::move(data);
*rv = obj;
} else {
// object container
MapObject::ContainerType data;
for (int i = 0; i < args.size(); i += 2) {
CHECK(args[i].type_code() == kObjectHandle)
<< "The key of the map must be an object type";
CHECK(args[i + 1].type_code() == kObjectHandle)
<< "The value of the map must be an object type";
data.emplace(std::make_pair(args[i].obj_sptr(), args[i + 1].obj_sptr()));
}
auto obj = std::make_shared<MapObject>();
obj->data = std::move(data);
*rv = obj;
}
});
DGL_REGISTER_GLOBAL("_MapSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
if (sptr->is_type<MapObject>()) {
auto* o = static_cast<const MapObject*>(sptr.get());
*rv = static_cast<int64_t>(o->data.size());
} else {
CHECK(sptr->is_type<StrMapObject>());
auto* o = static_cast<const StrMapObject*>(sptr.get());
*rv = static_cast<int64_t>(o->data.size());
}
});
DGL_REGISTER_GLOBAL("_MapGetItem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
if (sptr->is_type<MapObject>()) {
auto* o = static_cast<const MapObject*>(sptr.get());
auto it = o->data.find(args[1].obj_sptr());
CHECK(it != o->data.end()) << "cannot find the key in the map";
*rv = (*it).second;
} else {
CHECK(sptr->is_type<StrMapObject>());
auto* o = static_cast<const StrMapObject*>(sptr.get());
auto it = o->data.find(args[1].operator std::string());
CHECK(it != o->data.end()) << "cannot find the key in the map";
*rv = (*it).second;
}
});
DGL_REGISTER_GLOBAL("_MapItems")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
if (sptr->is_type<MapObject>()) {
auto* o = static_cast<const MapObject*>(sptr.get());
auto rkvs = std::make_shared<ListObject>();
for (const auto& kv : o->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*rv = rkvs;
} else {
CHECK(sptr->is_type<StrMapObject>());
auto* o = static_cast<const StrMapObject*>(sptr.get());
auto rkvs = std::make_shared<ListObject>();
for (const auto& kv : o->data) {
rkvs->data.push_back(MakeValue(kv.first));
rkvs->data.push_back(kv.second);
}
*rv = rkvs;
}
});
DGL_REGISTER_GLOBAL("_MapCount")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
if (sptr->is_type<MapObject>()) {
auto* o = static_cast<const MapObject*>(sptr.get());
*rv = static_cast<int64_t>(o->data.count(args[1].obj_sptr()));
} else {
CHECK(sptr->is_type<StrMapObject>());
auto* o = static_cast<const StrMapObject*>(sptr.get());
*rv = static_cast<int64_t>(o->data.count(args[1].operator std::string()));
}
});
DGL_REGISTER_GLOBAL("_Value")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = MakeValue(args[0]);
});
DGL_REGISTER_GLOBAL("_ValueGet")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto& sptr = args[0].obj_sptr();
CHECK(sptr->is_type<ValueObject>());
auto* o = static_cast<const ValueObject*>(sptr.get());
*rv = o->data;
});
} // namespace runtime
} // namespace dgl
/*!
* Copyright (c) 2016 by Contributors
* Implementation of C API (reference: tvm/src/api/c_api.cc)
* \file c_api.cc
*/
#include <dmlc/base.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h>
#include <vector>
#include <string>
#include <exception>
#include "runtime_base.h"
/*! \brief entry to to easily hold returning information */
struct DGLAPIThreadLocalEntry {
/*! \brief result holder for returning strings */
std::vector<std::string> ret_vec_str;
/*! \brief result holder for returning string pointers */
std::vector<const char *> ret_vec_charp;
/*! \brief result holder for retruning string */
std::string ret_str;
};
using namespace dgl::runtime;
/*! \brief Thread local store that can be used to hold return values. */
typedef dmlc::ThreadLocalStore<DGLAPIThreadLocalEntry> DGLAPIThreadLocalStore;
using DGLAPIObject = std::shared_ptr<Object>;
struct APIAttrGetter : public AttrVisitor {
std::string skey;
DGLRetValue* ret;
bool found_object_ref{false};
void Visit(const char* key, double* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, int64_t* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, uint64_t* value) final {
CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
<< "cannot return too big constant";
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, int* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, bool* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, std::string* value) final {
if (skey == key) *ret = value[0];
}
void Visit(const char* key, ObjectRef* value) final {
if (skey == key) {
*ret = value[0];
found_object_ref = true;
}
}
};
struct APIAttrDir : public AttrVisitor {
std::vector<std::string>* names;
void Visit(const char* key, double* value) final {
names->push_back(key);
}
void Visit(const char* key, int64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, uint64_t* value) final {
names->push_back(key);
}
void Visit(const char* key, bool* value) final {
names->push_back(key);
}
void Visit(const char* key, int* value) final {
names->push_back(key);
}
void Visit(const char* key, std::string* value) final {
names->push_back(key);
}
void Visit(const char* key, ObjectRef* value) final {
names->push_back(key);
}
};
int DGLObjectFree(ObjectHandle handle) {
API_BEGIN();
delete static_cast<DGLAPIObject*>(handle);
API_END();
}
int DGLObjectTypeKey2Index(const char* type_key,
int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(Object::TypeKey2Index(type_key));
API_END();
}
int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index) {
API_BEGIN();
*out_index = static_cast<int>(
(*static_cast<DGLAPIObject*>(handle))->type_index());
API_END();
}
int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* ret_val,
int* ret_type_code,
int* ret_success) {
API_BEGIN();
DGLRetValue rv;
APIAttrGetter getter;
getter.skey = key;
getter.ret = &rv;
DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
if (getter.skey == "type_key") {
ret_val->v_str = (*tobject)->type_key();
*ret_type_code = kStr;
*ret_success = 1;
} else {
(*tobject)->VisitAttrs(&getter);
*ret_success = getter.found_object_ref || rv.type_code() != kNull;
if (rv.type_code() == kStr ||
rv.type_code() == kDGLType) {
DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
}
API_END();
}
int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array) {
DGLAPIThreadLocalEntry *ret = DGLAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str.clear();
DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
APIAttrDir dir;
dir.names = &(ret->ret_vec_str);
(*tobject)->VisitAttrs(&dir);
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
}
*out_array = dmlc::BeginPtr(ret->ret_vec_charp);
*out_size = static_cast<int>(ret->ret_vec_str.size());
API_END();
}
/*!
* Copyright (c) 2019 by Contributors
* \file runtime/object.cc
* \brief Implementation of runtime object APIs.
*/
#include <dgl/runtime/object.h>
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_map>
namespace dgl {
namespace runtime {
namespace {
// single manager of operator information.
struct TypeManager {
// mutex to avoid registration from multiple threads.
// recursive is needed for trigger(which calls UpdateAttrMap)
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> key2index;
std::vector<std::string> index2key;
// get singleton of the
static TypeManager* Global() {
static TypeManager inst;
return &inst;
}
};
} // namespace
const bool Object::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Object::_type_key);
return tid == tindex;
}
// this is slow, usually caller always hold the result in a static variable.
uint32_t Object::TypeKey2Index(const char* key) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
std::string skey = key;
auto it = t->key2index.find(skey);
if (it != t->key2index.end()) {
return it->second;
}
uint32_t tid = ++(t->type_counter);
t->key2index[skey] = tid;
t->index2key.push_back(skey);
return tid;
}
const char* Object::TypeIndex2Key(uint32_t index) {
TypeManager *t = TypeManager::Global();
std::lock_guard<std::mutex>(t->mutex);
CHECK_NE(index, 0);
return t->index2key.at(index - 1).c_str();
}
} // namespace runtime
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment