Unverified Commit 90e78c58 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[FFI] FFI container support, custom structure extension via Object (#693)

* WIP: import tvm runtime node system

* WIP: object system

* containers

* tested basic container composition

* tested custom object

* fix setattr bug

* tested object container return

* fix lint

* some comments about get/set state

* fix lint

* fix lint

* update cython

* fix cython

* ffi doc

* fix doc
parent 684a61ad
...@@ -93,6 +93,7 @@ file(GLOB DGL_SRC ...@@ -93,6 +93,7 @@ file(GLOB DGL_SRC
) )
file(GLOB_RECURSE DGL_SRC_1 file(GLOB_RECURSE DGL_SRC_1
src/api/*.cc
src/graph/*.cc src/graph/*.cc
src/scheduler/*.cc src/scheduler/*.cc
) )
......
.. currentmodule:: dgl
DGL Foreign Function Interface (FFI)
====================================
We all like Python because it is easy to manipulate. We all like C because it
is fast, reliable and typed. To have the merits of both ends, DGL is mostly in
python, for quick prototyping, while lowers the performance-critical part to C.
Thus, DGL developers frequently face the scenario to write a C routine and has
it exposed to python, via a mechanism called *Foreign Function Interface (FFI)*.
There are many FFI solutions out there. In DGL, we want to keep it simple,
intuitive and efficient for critical use cases. That's why when we came across the
FFI solution in the TVM project, we immediately fell for it. It exploits the idea of
functional programming so that it exposes only a dozens of C APIs and new APIs
can be built upon it.
We decided to borrow the idea (shamelessly). For example, to define a C
API that is exposed to python is only a few lines of codes:
.. code:: c++
// file: calculator.cc (put it in dgl/src folder)
#include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
DGL_REGISTER_GLOBAL("calculator.MyAdd")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int a = args[0];
int b = args[1];
*rv = a * b;
});
Compile and build the library. On the python side, create a
``calculator.py`` file under ``dgl/python/dgl/``
.. code:: python
# file: calculator.py
from ._ffi.function import _init_api
def add(a, b):
# MyAdd has been registered via `_ini_api` call below
return MyAdd(a, b)
_init_api("dgl.calculator")
The trick is that the FFI system first masks the type information of the
function arguments, so all the C function calls can go through one C API
(``DGLFuncCall``). The type information is retrieved in the function body by
static conversion, and we will do runtime type check to make sure that the type
conversion is correct. The overhead of such back-and-forth is negligible as
long as the function call is not too light (the above example is actually a bad
one). TVM's `PackedFunc
document <https://docs.tvm.ai/dev/runtime.html#packedfunc>`_ has more details.
Defining new types
------------------
``DGLArgs`` and ``DGLRetValue`` only support a limited number of types:
* Numerical values: int, float, double, ...
* string
* Function (in the form of PackedFunc)
* NDArray
Though limited, the above type system is very powerful because it supports
function as a first-class citizen. For example, if you want to return multiple
values, you can return a PackedFunc which returns each value given an integer
index. However, in many cases, new types are still desired to ease the
development process:
* The argument/return value is a composition of collections (e.g. dictionary of
dictionary of list).
* Sometimes we just want to have a notion of "structure" (e.g. given an apple,
get its color by ``apple.color``).
To achieve this, we introduce the Object type system. For example, to define a
new type ``Calculator``:
.. code:: c++
// file: calculator.cc
#include <dgl/packed_func_ext.h>
using namespace runtime;
class CalculatorObject : public Object {
public:
std::string brand;
int price;
void VisitAttrs(AttrVisitor *v) final {
v->Visit("brand", &brand);
v->Visit("price", &price);
}
static constexpr const char* _type_key = "Calculator";
DGL_DECLARE_OBJECT_TYPE_INFO(CalculatorObject, Object);
};
// This is to define a reference class (the wrapper of an object shared pointer).
// A minimal implementation is as follows, but you could define extra methods.
class Calculator : public ObjectRef {
public:
const CalculatorObject* operator->() const {
return static_cast<const CalculatorObject*>(obj_.get());
}
using ContainerType = CalculatorObject;
};
DGL_REGISTER_GLOBAL("calculator.CreateCaculator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string brand = args[0];
int price = args[1];
auto o = std::make_shared<CalculatorObject>();
o->brand = brand;
o->price = price;
*rv = o;
}
On the python side:
.. code:: python
# file: calculator.py
from dgl._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
@register_object
class Calculator(ObjectBase):
@staticmethod
def create(brand, price):
# invoke a C API, the return value is of `Calculator` type
return CreateCalculator(brand, price)
_init_api("dgl.calculator")
We can then simply create ``Calculator`` object by:
.. code:: python
calc = Calculator.create("casio", 100)
What is nice about this object is that, it defines a visitor pattern that is
essentially a reflection mechanism to get its internal attributes. For example,
you can print the calculator's brand and by simply accessing its attributes.
.. code:: python
print(calc.brand)
print(calc.price)
The reflection is indeed a little bit slow due to the string key lookup. To
speed it up, you could define an attribute access API:
.. code:: c++
// file: calculator.cc
DGL_REGISTER_GLOBAL("calculator.CaculatorGetBrand")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
Calculator calc = args[0];
*rv = calc->brand;
}
Containers
----------
Containers are also objects. For example, the C API below accepts a list of
integers and return their sum:
.. code:: c++
// in file: calculator.cc
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.Sum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
// All the DGL supported values are represented as a ValueObject, which
// contains a data field.
List<Value> values = args[0];
int sum = 0;
for (int i = 0; i < values.size(); ++i) {
sum += static_cast<int>(values[i]->data);
}
}
Invoking this API is simple -- just pass a python list of integers. DGL FFI will
automatically convert python list/tuple/dictionary to the corresponding object
type.
.. code:: python
# in file: calculator.py
from ._ffi.function import _init_api
Sum([0, 1, 2, 3, 4, 5])
_init_api("dgl.calculator")
The elements in the containers can be any objects, which allows the containers
to be composed. Below is an API that accepts a list of calculators and print
out their price:
.. code:: c++
// in file: calculator.cc
#include <iostream>
#include <dgl/runtime/container.h>
using namespace runtime;
DGL_REGISTER_GLOBAL("calculator.PrintCalculators")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<Calculator> calcs = args[0];
for (int i = 0; i < calcs.size(); ++i) {
std::cout << calcs[i]->price << std::endl;
}
}
Please note that containers are NOT meant for passing a large collection of
items from/to C APIs. It will be quite slow in these cases. It is recommended
to benchmark first. As an alternative, use NDArray for a large collection of
numerical values and use BatchedDGLGraph for a lot of graphs.
...@@ -188,11 +188,19 @@ Or go through all of them :doc:`here <tutorials/models/index>`. ...@@ -188,11 +188,19 @@ Or go through all of them :doc:`here <tutorials/models/index>`.
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
:caption: Notes :caption: Developer Notes
:hidden: :hidden:
:glob: :glob:
contribute contribute
developer/ffi
.. toctree::
:maxdepth: 1
:caption: Misc
:hidden:
:glob:
faq faq
env_var env_var
resources resources
......
/*!
* Copyright (c) 2019 by Contributors
* \file packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass ObjectRef types into/from PackedFunc.
*/
#ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_
#include <sstream>
#include <string>
#include <memory>
#include <type_traits>
#include "./runtime/packed_func.h"
#include "./runtime/object.h"
#include "./runtime/container.h"
namespace dgl {
namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct ObjectTypeChecker {
static inline bool Check(Object* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
struct ObjectTypeChecker<List<T> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ListObject>()) return false;
ListObject* n = static_cast<ListObject*>(sptr);
for (const auto& p : n->data) {
if (!ObjectTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "list<";
ObjectTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<StrMapObject>()) return false;
StrMapObject* n = static_cast<StrMapObject*>(sptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapObject>()) return false;
MapObject* n = static_cast<MapObject*>(sptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<K>::Check(kv.first.get())) return false;
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
ObjectTypeChecker<V>::PrintName(os);
os << '>';
}
};
template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
return os.str();
}
// extensions for DGLArgValue
template<typename TObjectRef>
inline TObjectRef DGLArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef derived class");
if (type_code_ == kNull) return TObjectRef();
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TObjectRef>()
<< " but get " << sptr->type_key();
return TObjectRef(sptr);
}
inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return *ptr<std::shared_ptr<Object> >();
}
template<typename TObjectRef, typename>
inline bool DGLArgValue::IsObjectType() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr =
*ptr<std::shared_ptr<Object> >();
return ObjectTypeChecker<TObjectRef>::Check(sptr.get());
}
// extensions for DGLRetValue
inline DGLRetValue& DGLRetValue::operator=(
const std::shared_ptr<Object>& other) {
if (other.get() == nullptr) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Object> >(kObjectHandle, other);
}
return *this;
}
inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {
if (!other.defined()) {
SwitchToPOD(kNull);
} else {
SwitchToClass<std::shared_ptr<Object> >(kObjectHandle, other.obj_);
}
return *this;
}
template<typename TObjectRef>
inline TObjectRef DGLRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
if (type_code_ == kNull) return TObjectRef();
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return TObjectRef(*ptr<std::shared_ptr<Object> >());
}
inline void DGLArgsSetter::operator()(size_t i, const ObjectRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));
type_codes_[i] = kObjectHandle;
} else {
type_codes_[i] = kNull;
}
}
} // namespace runtime
} // namespace dgl
#endif // DGL_PACKED_FUNC_EXT_H_
/*!
* Copyright (c) 2019 by Contributors
* \file dgl/runtime/c_object_api.h
*
* \brief DGL Object C API, used to extend and prototype new CAPIs.
*
* \note Most API functions are registerd as PackedFunc and
* can be grabbed via DGLFuncGetGlobal
*/
#ifndef DGL_RUNTIME_C_OBJECT_API_H_
#define DGL_RUNTIME_C_OBJECT_API_H_
#include "./c_runtime_api.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief handle to object */
typedef void* ObjectHandle;
/*!
* \brief free the object handle
* \param handle The object handle to be freed.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectFree(ObjectHandle handle);
/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
int* out_index);
/*!
* \brief Get runtime type index of the object.
* \param handle the object handle.
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index);
/*!
* \brief get attributes given key
* \param handle The object handle
* \param key The attribute name
* \param out_value The attribute value
* \param out_type_code The type code of the attribute.
* \param out_success Whether get is successful.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* out_value,
int* out_type_code,
int* out_success);
/*!
* \brief get attributes names in the object.
* \param handle The object handle
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array);
#ifdef __cplusplus
} // DGL_EXTERN_C
#endif
#endif // DGL_RUNTIME_C_OBJECT_API_H_
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
* \file dgl/runtime/c_runtime_api.h * \file dgl/runtime/c_runtime_api.h
* \brief DGL runtime library. * \brief DGL runtime library.
* *
* This runtime is adapted from TVM project * This runtime is adapted from TVM project (commit: 2ce5277)
*/ */
#ifndef DGL_RUNTIME_C_RUNTIME_API_H_ #ifndef DGL_RUNTIME_C_RUNTIME_API_H_
#define DGL_RUNTIME_C_RUNTIME_API_H_ #define DGL_RUNTIME_C_RUNTIME_API_H_
...@@ -72,7 +72,7 @@ typedef enum { ...@@ -72,7 +72,7 @@ typedef enum {
kDGLType = 5U, kDGLType = 5U,
kDGLContext = 6U, kDGLContext = 6U,
kArrayHandle = 7U, kArrayHandle = 7U,
kNodeHandle = 8U, kObjectHandle = 8U,
kModuleHandle = 9U, kModuleHandle = 9U,
kFuncHandle = 10U, kFuncHandle = 10U,
kStr = 11U, kStr = 11U,
......
/*!
* Copyright (c) 2019 by Contributors
* \file runtime/container.h
* \brief Defines the container object data structures.
*/
#ifndef DGL_RUNTIME_CONTAINER_H_
#define DGL_RUNTIME_CONTAINER_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include "object.h"
#include "packed_func.h"
namespace dgl {
namespace runtime {
/*!
* \brief value object.
*
* It is typically used to wrap a non-Object type to Object type.
* Any type that is supported by DGLRetValue is supported by this.
*/
class ValueObject : public Object {
public:
/*! \brief the value data */
DGLRetValue data;
static constexpr const char* _type_key = "Value";
DGL_DECLARE_OBJECT_TYPE_INFO(ValueObject, Object);
};
/*! \brief Construct a value object. */
template <typename T>
inline std::shared_ptr<ValueObject> MakeValue(T&& val) {
auto obj = std::make_shared<ValueObject>();
obj->data = val;
return obj;
}
/*! \brief Vallue reference type */
class Value : public ObjectRef {
public:
Value() {}
explicit Value(std::shared_ptr<Object> o): ObjectRef(o) {}
const ValueObject* operator->() const {
return static_cast<const ValueObject*>(obj_.get());
}
using ContainerType = ValueObject;
};
/*! \brief list obj content in list */
class ListObject : public Object {
public:
/*! \brief the data content */
std::vector<std::shared_ptr<Object> > data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to list have no effect.
}
static constexpr const char* _type_key = "List";
DGL_DECLARE_OBJECT_TYPE_INFO(ListObject, Object);
};
/*! \brief map obj content */
class MapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
// hash function
struct Hash {
size_t operator()(const std::shared_ptr<Object>& n) const {
return std::hash<Object*>()(n.get());
}
};
// comparator
struct Equal {
bool operator()(
const std::shared_ptr<Object>& a,
const std::shared_ptr<Object>& b) const {
return a.get() == b.get();
}
};
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::shared_ptr<Object>,
std::shared_ptr<Object>,
Hash, Equal>;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "Map";
DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object);
};
/*! \brief specialized map obj with string as key */
class StrMapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::string,
std::shared_ptr<Object> >;
/*! \brief the data content */
ContainerType data;
static constexpr const char* _type_key = "StrMap";
DGL_DECLARE_OBJECT_TYPE_INFO(StrMapObject, Object);
};
/*!
* \brief iterator adapter that adapts TIter to return another type.
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template<typename Converter,
typename TIter>
class IterAdapter {
public:
explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter& operator++(int) { // NOLINT(*)
++iter_;
return *this;
}
inline IterAdapter operator+(int offset) const { // NOLINT(*)
return IterAdapter(iter_ + offset);
}
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
inline bool operator!=(IterAdapter other) const {
return !(*this == other);
}
inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_);
}
private:
TIter iter_;
};
/*!
* \brief List container of ObjectRef.
*
* List implements copy on write semantics, which means list is mutable
* but copy will happen when list is referenced in more than two places.
*
* That is said when using this container for runtime arguments or return
* values, try use the constructor to create the list at once (for example
* from an existing vector).
*
* operator[] only provide const acces, use Set to mutate the content.
*
* \tparam T The content ObjectRef type.
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
class List : public ObjectRef {
public:
/*!
* \brief default constructor
*/
List() {
obj_ = std::make_shared<ListObject>();
}
/*!
* \brief move constructor
* \param other source
*/
List(List<T> && other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
List(const List<T> &other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit List(std::shared_ptr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
List(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
List(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
List(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
*/
explicit List(size_t n, const T& val) {
auto tmp_obj = std::make_shared<ListObject>();
for (size_t i = 0; i < n; ++i) {
tmp_obj->data.push_back(val.obj_);
}
obj_ = std::move(tmp_obj);
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(List<T> && other) {
obj_ = std::move(other.obj_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(const List<T> & other) {
obj_ = other.obj_;
return *this;
}
/*!
* \brief reset the list to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<ListObject>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back((*it).obj_);
}
obj_ = std::move(n);
}
/*!
* \brief Read i-th element from list.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return T(static_cast<const ListObject*>(obj_.get())->data[i]);
}
/*! \return The size of the list */
inline size_t size() const {
if (obj_.get() == nullptr) return 0;
return static_cast<const ListObject*>(obj_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the list.
* Otherwise make a new copy of the list to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline ListObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<ListObject>(
*static_cast<const ListObject*>(obj_.get()));
}
return static_cast<ListObject*>(obj_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ListObject* n = this->CopyOnWrite();
n->data.push_back(item.obj_);
}
/*!
* \brief set i-th element of the list.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ListObject* n = this->CopyOnWrite();
n->data[i] = value.obj_;
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container obj */
using ContainerType = ListObject;
struct Ptr2ObjectRef {
using ResultType = T;
static inline T convert(const std::shared_ptr<Object>& n) {
return T(n);
}
};
using iterator = IterAdapter<Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_iterator>;
using reverse_iterator = IterAdapter<
Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ListObject*>(obj_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ListObject*>(obj_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rend());
}
};
/*!
* \brief Map container of ObjectRef->ObjectRef.
*
* Map implements copy on write semantics, which means map is mutable
* but copy will happen when list is referenced in more than two places.
*
* That is said when using this container for runtime arguments or return
* values, try use the constructor to create it at once (for example
* from an existing std::map).
*
* operator[] only provide const acces, use Set to mutate the content.
*
* \tparam K The key ObjectRef type.
* \tparam V The value ObjectRef type.
*/
template<typename K,
typename V,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Map() {
obj_ = std::make_shared<MapObject>();
}
/*!
* \brief move constructor
* \param other source
*/
Map(Map<K, V> && other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
template<typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(Map<K, V> && other) {
obj_ = std::move(other.obj_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(const Map<K, V> & other) {
obj_ = other.obj_;
return *this;
}
/*!
* \brief reset the list to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::shared_ptr<MapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.obj_,
i->second.obj_));
}
obj_ = std::move(n);
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V operator[](const K& key) const {
return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_));
}
/*!
* \brief Read element from map.
* \param key The key
* \return the corresonding element.
*/
inline const V at(const K& key) const {
return V(static_cast<const MapObject*>(obj_.get())->data.at(key.obj_));
}
/*! \return The size of the list */
inline size_t size() const {
if (obj_.get() == nullptr) return 0;
return static_cast<const MapObject*>(obj_.get())->data.size();
}
/*! \return The size of the list */
inline size_t count(const K& key) const {
if (obj_.get() == nullptr) return 0;
return static_cast<const MapObject*>(obj_.get())->data.count(key.obj_);
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the list.
* Otherwise make a new copy of the list to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline MapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
return static_cast<MapObject*>(obj_.get());
}
/*!
* \brief set the Map.
* \param key The index key.
* \param value The value to be setted.
*/
inline void Set(const K& key, const V& value) {
MapObject* n = this->CopyOnWrite();
n->data[key.obj_] = value.obj_;
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
/*! \brief specify container obj */
using ContainerType = MapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair<
std::shared_ptr<Object>,
std::shared_ptr<Object> >& n) {
return std::make_pair(K(n.first), V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
}
};
// specialize of string map
template<typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() {
obj_ = std::make_shared<StrMapObject>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
Map(const Map<std::string, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
}
explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}
template<typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V> && other) {
obj_ = std::move(other.obj_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
obj_ = other.obj_;
return *this;
}
template<typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<StrMapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first,
i->second.obj_));
}
obj_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
return V(static_cast<const StrMapObject*>(obj_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
return V(static_cast<const StrMapObject*>(obj_.get())->data.at(key));
}
inline size_t size() const {
if (obj_.get() == nullptr) return 0;
return static_cast<const StrMapObject*>(obj_.get())->data.size();
}
inline size_t count(const std::string& key) const {
if (obj_.get() == nullptr) return 0;
return static_cast<const StrMapObject*>(obj_.get())->data.count(key);
}
inline StrMapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
return static_cast<StrMapObject*>(obj_.get());
}
inline void Set(const std::string& key, const V& value) {
StrMapObject* n = this->CopyOnWrite();
n->data[key] = value.obj_;
}
inline bool empty() const {
return size() == 0;
}
using ContainerType = StrMapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<
std::string,
std::shared_ptr<Object> >& n) {
return std::make_pair(n.first, V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.find(key));
}
};
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_CONTAINER_H_
/*!
* Copyright (c) 2019 by Contributors
* \file runtime/object.h
* \brief Defines the Object data structures.
*/
#ifndef DGL_RUNTIME_OBJECT_H_
#define DGL_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <type_traits>
namespace dgl {
namespace runtime {
// forward declaration
class Object;
class ObjectRef;
/*!
* \brief Visitor class to each object attribute.
* The content is going to be called for each field.
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
};
/*!
* \brief base class of object container.
* All object's internal is stored as std::shared_ptr<Object>
*/
class Object {
public:
/*! \brief virtual destructor */
virtual ~Object() {}
/*! \return The unique type key of the object */
virtual const char* type_key() const = 0;
/*!
* \brief Apply visitor to each field of the Object
* Visitor could mutate the content of the object.
* override if Object contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*! \return the type index of the object */
virtual const uint32_t type_index() const = 0;
/*!
* \brief Whether this object derives from object with type_index=tid.
* Implemented by DGL_DECLARE_OBJECT_TYPE_INFO
*
* \param tid The type index.
* \return the check result.
*/
virtual const bool _DerivedFrom(uint32_t tid) const;
/*!
* \brief get a runtime unique type index given a type key
* \param type_key Type key of a type.
* \return the corresponding type index.
*/
static uint32_t TypeKey2Index(const char* type_key);
/*!
* \brief get type key from type index.
* \param index The type index
* \return the corresponding type key.
*/
static const char* TypeIndex2Key(uint32_t index);
/*!
* \return whether the type is derived from
*/
template<typename T>
inline bool derived_from() const;
/*!
* \return whether the object is of type T
* \tparam The type to be checked.
*/
template<typename T>
inline bool is_type() const;
// object ref can see this
friend class ObjectRef;
static constexpr const char* _type_key = "Object";
};
/*! \brief base class of all reference object */
class ObjectRef {
public:
/*! \brief type indicate the container type */
using ContainerType = Object;
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
*
* \param other Another object ref.
* \return the compare result.
* \sa same_as
*/
inline bool operator==(const ObjectRef& other) const;
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
*
* \param other Another object ref.
* \param other Another object ref.
* \return the compare result.
*/
inline bool same_as(const ObjectRef& other) const;
/*!
* \brief Comparator
*
* The operator overload allows ObjectRef be used in std::map.
*
* \param other Another object ref.
* \return the compare result.
*/
inline bool operator<(const ObjectRef& other) const;
/*!
* \brief Comparator
* \param other Another object ref.
* \return the compare result.
* \sa same_as
*/
inline bool operator!=(const ObjectRef& other) const;
/*! \return the hash function for ObjectRef */
inline size_t hash() const;
/*! \return whether the expression is null */
inline bool defined() const;
/*! \return the internal type index of Object */
inline uint32_t type_index() const;
/*! \return the internal object pointer */
inline const Object* get() const;
/*! \return the internal object pointer */
inline const Object* operator->() const;
/*!
* \brief Downcast this object to its actual type.
* This returns nullptr if the object is not of the requested type.
* Example usage:
*
* if (const Banana *banana = obj->as<Banana>()) {
* // This is a Banana!
* }
* \tparam T the target type, must be subtype of Object
*/
template<typename T>
inline const T *as() const;
/*! \brief default constructor */
ObjectRef() = default;
explicit ObjectRef(std::shared_ptr<Object> obj) : obj_(obj) {}
/*! \brief the internal object, do not touch */
std::shared_ptr<Object> obj_;
};
/*!
* \brief helper macro to declare type information in a base object.
*
* This is macro should be used in abstract base class definition
* because it does not define type_key and type_index.
*/
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
* \brief helper macro to declare type information in a terminal class
*
* This is macro should be used in terminal class definition.
*
* For example:
*
* // This class is an abstract class and cannot create instances
* class SomeBaseClass : public Node {
* public:
* static constexpr const char* _type_key = "some_base";
* DGL_DECLARE_BASE_OBJECT_INFO(SomeBaseClass, Node);
* };
*
* // Child class that allows instantiation
* class SomeChildClass : public SomeBaseClass {
* public:
* static constexpr const char* _type_key = "some_child";
* DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass);
* };
*/
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
const uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
const bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
// implementations of inline functions after this
template<typename T>
inline bool Object::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
inline bool Object::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline const Object* ObjectRef::get() const {
return obj_.get();
}
inline const Object* ObjectRef::operator->() const {
return obj_.get();
}
inline bool ObjectRef::defined() const {
return obj_.get() != nullptr;
}
inline bool ObjectRef::operator==(const ObjectRef& other) const {
return obj_.get() == other.obj_.get();
}
inline bool ObjectRef::same_as(const ObjectRef& other) const {
return obj_.get() == other.obj_.get();
}
inline bool ObjectRef::operator<(const ObjectRef& other) const {
return obj_.get() < other.obj_.get();
}
inline bool ObjectRef::operator!=(const ObjectRef& other) const {
return obj_.get() != other.obj_.get();
}
inline size_t ObjectRef::hash() const {
return std::hash<Object*>()(obj_.get());
}
inline uint32_t ObjectRef::type_index() const {
CHECK(obj_.get() != nullptr) << "null type";
return get()->type_index();
}
template<typename T>
inline const T* ObjectRef::as() const {
const Object* ptr = get();
if (ptr && ptr->is_type<T>()) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
/*! \brief The hash function for nodes */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
return a.hash();
}
};
/*! \brief The equal comparator for nodes */
struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
return a.get() == b.get();
}
};
} // namespace runtime
} // namespace dgl
namespace std {
template <>
struct hash<::dgl::runtime::ObjectRef> {
std::size_t operator()(const ::dgl::runtime::ObjectRef& k) const {
return k.hash();
}
};
} // namespace std
#endif // DGL_RUNTIME_OBJECT_H_
...@@ -24,13 +24,14 @@ ...@@ -24,13 +24,14 @@
#endif #endif
namespace dgl { namespace dgl {
// Forward declare NodeRef and Node for extensions. namespace runtime {
// This header works fine without depend on NodeRef
// Forward declare ObjectRef and Object for extensions.
// This header works fine without depend on ObjectRef
// as long as it is not used. // as long as it is not used.
class Node; class Object;
class NodeRef; class ObjectRef;
namespace runtime {
// forward declarations // forward declarations
class DGLArgs; class DGLArgs;
class DGLArgValue; class DGLArgValue;
...@@ -520,19 +521,25 @@ class DGLArgValue : public DGLPODValue_ { ...@@ -520,19 +521,25 @@ class DGLArgValue : public DGLPODValue_ {
const DGLValue& value() const { const DGLValue& value() const {
return value_; return value_;
} }
// Deferred extension handler. // Deferred extension handler.
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef AsNodeRef() const; inline TObjectRef AsObjectRef() const;
// Convert this value to arbitrary class type
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TNodeRef,
// Return true if the value is of TObjectRef type
template<typename TObjectRef,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type> std::is_class<TObjectRef>::value>::type>
inline bool IsNodeType() const; inline bool IsObjectType() const;
// get internal node ptr, if it is node // get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr(); inline std::shared_ptr<Object>& obj_sptr();
}; };
/*! /*!
...@@ -714,21 +721,21 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -714,21 +721,21 @@ class DGLRetValue : public DGLPODValue_ {
} }
/*! \return The value field, if the data is POD */ /*! \return The value field, if the data is POD */
const DGLValue& value() const { const DGLValue& value() const {
CHECK(type_code_ != kNodeHandle && CHECK(type_code_ != kObjectHandle &&
type_code_ != kFuncHandle && type_code_ != kFuncHandle &&
type_code_ != kModuleHandle && type_code_ != kModuleHandle &&
type_code_ != kStr) << "DGLRetValue.value can only be used for POD data"; type_code_ != kStr) << "DGLRetValue.value can only be used for POD data";
return value_; return value_;
} }
// NodeRef related extenstions: in dgl/packed_func_ext.h // ObjectRef related extenstions: in dgl/packed_func_ext.h
template<typename T, template<typename T,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_class<T>::value>::type> std::is_class<T>::value>::type>
inline operator T() const; inline operator T() const;
template<typename TNodeRef> template<typename TObjectRef>
inline TNodeRef AsNodeRef() const; inline TObjectRef AsObjectRef() const;
inline DGLRetValue& operator=(const NodeRef& other); inline DGLRetValue& operator=(const ObjectRef& other);
inline DGLRetValue& operator=(const std::shared_ptr<Node>& other); inline DGLRetValue& operator=(const std::shared_ptr<Object>& other);
private: private:
template<typename T> template<typename T>
...@@ -754,9 +761,9 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -754,9 +761,9 @@ class DGLRetValue : public DGLPODValue_ {
*this = other.operator NDArray(); *this = other.operator NDArray();
break; break;
} }
case kNodeHandle: { case kObjectHandle: {
SwitchToClass<std::shared_ptr<Node> >( SwitchToClass<std::shared_ptr<Object> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >()); kObjectHandle, *other.template ptr<std::shared_ptr<Object> >());
break; break;
} }
default: { default: {
...@@ -801,7 +808,7 @@ class DGLRetValue : public DGLPODValue_ { ...@@ -801,7 +808,7 @@ class DGLRetValue : public DGLPODValue_ {
case kStr: delete ptr<std::string>(); break; case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break; case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break; case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break; case kObjectHandle: delete ptr<std::shared_ptr<Object> >(); break;
case kNDArrayContainer: { case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef(); static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break; break;
...@@ -828,7 +835,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -828,7 +835,7 @@ inline const char* TypeCode2Str(int type_code) {
case kBytes: return "bytes"; case kBytes: return "bytes";
case kHandle: return "handle"; case kHandle: return "handle";
case kNull: return "NULL"; case kNull: return "NULL";
case kNodeHandle: return "NodeHandle"; case kObjectHandle: return "ObjectHandle";
case kArrayHandle: return "ArrayHandle"; case kArrayHandle: return "ArrayHandle";
case kDGLType: return "DGLType"; case kDGLType: return "DGLType";
case kDGLContext: return "DGLContext"; case kDGLContext: return "DGLContext";
...@@ -1036,8 +1043,8 @@ class DGLArgsSetter { ...@@ -1036,8 +1043,8 @@ class DGLArgsSetter {
typename = typename std::enable_if< typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type> extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const; inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in dgl/packed_func_ext.h // ObjectRef related extenstions: in dgl/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*) inline void operator()(size_t i, const ObjectRef& other) const; // NOLINT(*)
private: private:
/*! \brief The values fields */ /*! \brief The values fields */
...@@ -1146,7 +1153,7 @@ namespace detail { ...@@ -1146,7 +1153,7 @@ namespace detail {
template<typename T, typename TSrc, bool is_ext> template<typename T, typename TSrc, bool is_ext>
struct DGLValueCast { struct DGLValueCast {
static T Apply(const TSrc* self) { static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>(); return self->template AsObjectRef<T>();
} }
}; };
......
...@@ -6,6 +6,7 @@ import socket ...@@ -6,6 +6,7 @@ import socket
from . import function from . import function
from . import nn from . import nn
from . import contrib from . import contrib
from . import container
from ._ffi.runtime_ctypes import TypeCode from ._ffi.runtime_ctypes import TypeCode
from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs from ._ffi.function import register_func, get_global_func, list_global_func_names, extract_ext_funcs
......
# C API and runtime # C API and runtime
Borrowed and adapted from TVM project. Borrowed and adapted from TVM project. (commit: 2ce5277)
...@@ -9,12 +9,15 @@ from numbers import Number, Integral ...@@ -9,12 +9,15 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext
from . import ndarray as _nd from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode from .types import DGLValue, TypeCode
from .types import DGLPackedCFunc, DGLCFuncFinalizer from .types import DGLPackedCFunc, DGLCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .object import ObjectBase
from . import object as _object
FunctionHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p
...@@ -79,7 +82,11 @@ def convert_to_dgl_func(pyfunc): ...@@ -79,7 +82,11 @@ def convert_to_dgl_func(pyfunc):
def _make_dgl_args(args, temp_args): def _make_dgl_args(args, temp_args):
"""Pack arguments into c args dgl call accept""" """Pack arguments into c args dgl call accept.
temp_args is used to temporarily save the arguments so they will not be
freed during C API function call.
"""
num_args = len(args) num_args = len(args)
values = (DGLValue * num_args)() values = (DGLValue * num_args)()
type_codes = (ctypes.c_int * num_args)() type_codes = (ctypes.c_int * num_args)()
...@@ -87,6 +94,14 @@ def _make_dgl_args(args, temp_args): ...@@ -87,6 +94,14 @@ def _make_dgl_args(args, temp_args):
if arg is None: if arg is None:
values[i].v_handle = None values[i].v_handle = None
type_codes[i] = TypeCode.NULL type_codes[i] = TypeCode.NULL
elif isinstance(arg, ObjectBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER type_codes[i] = (TypeCode.NDARRAY_CONTAINER
...@@ -189,7 +204,7 @@ def __init_handle_by_constructor__(fconstructor, args): ...@@ -189,7 +204,7 @@ def __init_handle_by_constructor__(fconstructor, args):
ctypes.byref(ret_val), ctypes.byref(ret_tcode))) ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args _ = temp_args
_ = args _ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE assert ret_tcode.value == TypeCode.OBJECT_HANDLE
handle = ret_val.v_handle handle = ret_val.v_handle
return handle return handle
...@@ -210,6 +225,7 @@ def _handle_return_func(x): ...@@ -210,6 +225,7 @@ def _handle_return_func(x):
# setup return handle for function type # setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
......
"""ctypes object API."""
from __future__ import absolute_import
import ctypes
from ..base import _LIB, check_call, c_str
from ..object_generic import _set_class_object_base
from .types import DGLValue, TypeCode
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None
"""Maps object type to its constructor"""
OBJECT_TYPE = {}
def _register_object(index, cls):
"""register object class in python"""
OBJECT_TYPE[index] = cls
def _return_object(x):
"""Construct a object object from the given DGLValue object"""
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
tindex = ctypes.c_int()
check_call(_LIB.DGLObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
obj.handle = handle
return obj
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
class ObjectBase(object):
"""Object base class"""
__slots__ = ["handle"]
# pylint: disable=no-member
def __del__(self):
if _LIB is not None:
check_call(_LIB.DGLObjectFree(self.handle))
def __getattr__(self, name):
if name == 'handle':
raise AttributeError("'handle' is a reserved attribute name that should not be used")
ret_val = DGLValue()
ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int()
check_call(_LIB.DGLObjectGetAttr(
self.handle, c_str(name),
ctypes.byref(ret_val),
ctypes.byref(ret_type_code),
ctypes.byref(ret_success)))
if not ret_success.value:
raise AttributeError(
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)
def __setattr__(self, name, value):
if name != 'handle':
raise AttributeError('Set attribute is not allowed for DGL object.')
object.__setattr__(self, name, value)
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Object object
instead of creating a new Object.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
_set_class_object_base(ObjectBase)
...@@ -15,7 +15,7 @@ cdef enum DGLTypeCode: ...@@ -15,7 +15,7 @@ cdef enum DGLTypeCode:
kDGLType = 5 kDGLType = 5
kDGLContext = 6 kDGLContext = 6
kArrayHandle = 7 kArrayHandle = 7
kNodeHandle = 8 kObjectHandle = 8
kModuleHandle = 9 kModuleHandle = 9
kFuncHandle = 10 kFuncHandle = 10
kStr = 11 kStr = 11
...@@ -62,7 +62,7 @@ ctypedef DGLArray* CDGLArrayHandle ...@@ -62,7 +62,7 @@ ctypedef DGLArray* CDGLArrayHandle
ctypedef void* DGLStreamHandle ctypedef void* DGLStreamHandle
ctypedef void* DGLRetValueHandle ctypedef void* DGLRetValueHandle
ctypedef void* DGLFunctionHandle ctypedef void* DGLFunctionHandle
ctypedef void* NodeHandle ctypedef void* ObjectHandle
ctypedef int (*DGLPackedCFunc)( ctypedef int (*DGLPackedCFunc)(
DGLValue* args, DGLValue* args,
...@@ -115,18 +115,17 @@ cdef extern from "dgl/runtime/c_runtime_api.h": ...@@ -115,18 +115,17 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
DLManagedTensor** out) DLManagedTensor** out)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor) void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
# (minjie): Node and class module are not used in DGL. cdef extern from "dgl/runtime/c_object_api.h":
#cdef extern from "dgl/c_dsl_api.h": int DGLObjectFree(ObjectHandle handle)
# int DGLNodeFree(NodeHandle handle) int DGLObjectTypeKey2Index(const char* type_key,
# int DGLNodeTypeKey2Index(const char* type_key, int* out_index)
# int* out_index) int DGLObjectGetTypeIndex(ObjectHandle handle,
# int DGLNodeGetTypeIndex(NodeHandle handle, int* out_index)
# int* out_index) int DGLObjectGetAttr(ObjectHandle handle,
# int DGLNodeGetAttr(NodeHandle handle, const char* key,
# const char* key, DGLValue* out_value,
# DGLValue* out_value, int* out_type_code,
# int* out_type_code, int* out_success)
# int* out_success)
cdef inline py_str(const char* x): cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3: if PY_MAJOR_VERSION < 3:
......
include "./base.pxi" include "./base.pxi"
# (minjie): Node and class module are not used in DGL. include "./object.pxi"
#include "./node.pxi"
include "./function.pxi" include "./function.pxi"
include "./ndarray.pxi" include "./ndarray.pxi"
...@@ -3,8 +3,7 @@ import traceback ...@@ -3,8 +3,7 @@ import traceback
from cpython cimport Py_INCREF, Py_DECREF from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
# (minjie): Node and class module are not used in DGL. from ..object_generic import convert_to_object, ObjectGeneric
# from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray
...@@ -25,9 +24,8 @@ cdef int dgl_callback(DGLValue* args, ...@@ -25,9 +24,8 @@ cdef int dgl_callback(DGLValue* args,
for i in range(num_args): for i in range(num_args):
value = args[i] value = args[i]
tcode = type_codes[i] tcode = type_codes[i]
# (minjie): Node and class module are not used in DGL. if (tcode == kObjectHandle or
#if (tcode == kNodeHandle or tcode == kFuncHandle or
if (tcode == kFuncHandle or
tcode == kModuleHandle or tcode == kModuleHandle or
tcode > kExtBegin): tcode > kExtBegin):
CALL(DGLCbArgToReturn(&value, tcode)) CALL(DGLCbArgToReturn(&value, tcode))
...@@ -81,11 +79,10 @@ cdef inline int make_arg(object arg, ...@@ -81,11 +79,10 @@ cdef inline int make_arg(object arg,
list temp_args) except -1: list temp_args) except -1:
"""Pack arguments into c args dgl call accept""" """Pack arguments into c args dgl call accept"""
cdef unsigned long long ptr cdef unsigned long long ptr
# (minjie): Node and class module are not used in DGL. if isinstance(arg, ObjectBase):
#if isinstance(arg, NodeBase): value[0].v_handle = (<ObjectBase>arg).chandle
# value[0].v_handle = (<NodeBase>arg).chandle tcode[0] = kObjectHandle
# tcode[0] = kNodeHandle elif isinstance(arg, NDArrayBase):
if isinstance(arg, NDArrayBase):
value[0].v_handle = (<NDArrayBase>arg).chandle value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kNDArrayContainer if tcode[0] = (kNDArrayContainer if
not (<NDArrayBase>arg).c_is_view else kArrayHandle) not (<NDArrayBase>arg).c_is_view else kArrayHandle)
...@@ -134,12 +131,11 @@ cdef inline int make_arg(object arg, ...@@ -134,12 +131,11 @@ cdef inline int make_arg(object arg,
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kStr
temp_args.append(tstr) temp_args.append(tstr)
# (minjie): Node and class module are not used in DGL. elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
#elif isinstance(arg, (list, tuple, dict, NodeGeneric)): arg = convert_to_object(arg)
# arg = convert_to_node(arg) value[0].v_handle = (<ObjectBase>arg).chandle
# value[0].v_handle = (<NodeBase>arg).chandle tcode[0] = kObjectHandle
# tcode[0] = kNodeHandle temp_args.append(arg)
# temp_args.append(arg)
#elif isinstance(arg, _CLASS_MODULE): #elif isinstance(arg, _CLASS_MODULE):
# value[0].v_handle = c_handle(arg.handle) # value[0].v_handle = c_handle(arg.handle)
# tcode[0] = kModuleHandle # tcode[0] = kModuleHandle
...@@ -170,10 +166,9 @@ cdef inline bytearray make_ret_bytes(void* chandle): ...@@ -170,10 +166,9 @@ cdef inline bytearray make_ret_bytes(void* chandle):
cdef inline object make_ret(DGLValue value, int tcode): cdef inline object make_ret(DGLValue value, int tcode):
"""convert result to return value.""" """convert result to return value."""
# (minjie): Node and class module are not used in DGL. if tcode == kObjectHandle:
#if tcode == kNodeHandle: return make_ret_object(value.v_handle)
# return make_ret_node(value.v_handle) elif tcode == kNull:
if tcode == kNull:
return None return None
elif tcode == kInt: elif tcode == kInt:
return value.v_int64 return value.v_int64
...@@ -189,7 +184,7 @@ cdef inline object make_ret(DGLValue value, int tcode): ...@@ -189,7 +184,7 @@ cdef inline object make_ret(DGLValue value, int tcode):
return ctypes_handle(value.v_handle) return ctypes_handle(value.v_handle)
elif tcode == kDGLContext: elif tcode == kDGLContext:
return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id) return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
# (minjie): Node and class module are not used in DGL. # (minjie): class module are not used in DGL.
#elif tcode == kModuleHandle: #elif tcode == kModuleHandle:
# return _CLASS_MODULE(ctypes_handle(value.v_handle)) # return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle: elif tcode == kFuncHandle:
......
from ... import _api_internal from ... import _api_internal
from ..base import string_types from ..base import string_types
from ..node_generic import _set_class_node_base from ..object_generic import _set_class_object_base
"""Maps node type to its constructor""" """Maps object type to its constructor"""
NODE_TYPE = [] OBJECT_TYPE = []
def _register_node(int index, object cls): def _register_object(int index, object cls):
"""register node class""" """register object class"""
while len(NODE_TYPE) <= index: while len(OBJECT_TYPE) <= index:
NODE_TYPE.append(None) OBJECT_TYPE.append(None)
NODE_TYPE[index] = cls OBJECT_TYPE[index] = cls
cdef inline object make_ret_node(void* chandle): cdef inline object make_ret_object(void* chandle):
global NODE_TYPE global OBJECT_TYPE
cdef int tindex cdef int tindex
cdef list node_type cdef list object_type
cdef object cls cdef object cls
node_type = NODE_TYPE object_type = OBJECT_TYPE
CALL(DGLNodeGetTypeIndex(chandle, &tindex)) CALL(DGLObjectGetTypeIndex(chandle, &tindex))
if tindex < len(node_type): if tindex < len(object_type):
cls = node_type[tindex] cls = object_type[tindex]
if cls is not None: if cls is not None:
obj = cls.__new__(cls) obj = cls.__new__(cls)
else: else:
obj = NodeBase.__new__(NodeBase) obj = ObjectBase.__new__(ObjectBase)
else: else:
obj = NodeBase.__new__(NodeBase) obj = ObjectBase.__new__(ObjectBase)
(<NodeBase>obj).chandle = chandle (<ObjectBase>obj).chandle = chandle
return obj return obj
cdef class NodeBase: cdef class ObjectBase:
cdef void* chandle cdef void* chandle
cdef _set_handle(self, handle): cdef _set_handle(self, handle):
...@@ -53,12 +53,12 @@ cdef class NodeBase: ...@@ -53,12 +53,12 @@ cdef class NodeBase:
self._set_handle(value) self._set_handle(value)
def __dealloc__(self): def __dealloc__(self):
CALL(DGLNodeFree(self.chandle)) CALL(DGLObjectFree(self.chandle))
def __getattr__(self, name): def __getattr__(self, name):
cdef DGLValue ret_val cdef DGLValue ret_val
cdef int ret_type_code, ret_succ cdef int ret_type_code, ret_succ
CALL(DGLNodeGetAttr(self.chandle, c_str(name), CALL(DGLObjectGetAttr(self.chandle, c_str(name),
&ret_val, &ret_type_code, &ret_succ)) &ret_val, &ret_type_code, &ret_succ))
if ret_succ == 0: if ret_succ == 0:
raise AttributeError( raise AttributeError(
...@@ -79,13 +79,13 @@ cdef class NodeBase: ...@@ -79,13 +79,13 @@ cdef class NodeBase:
Note Note
---- ----
We have a special calling convention to call constructor functions. We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object So the return handle is directly set into the Object object
instead of creating a new Node. instead of creating a new Object.
""" """
cdef void* chandle cdef void* chandle
ConstructorCall( ConstructorCall(
(<FunctionBase>fconstructor).chandle, (<FunctionBase>fconstructor).chandle,
kNodeHandle, args, &chandle) kObjectHandle, args, &chandle)
self.chandle = chandle self.chandle = chandle
_set_class_node_base(NodeBase) _set_class_object_base(ObjectBase)
"""Object namespace"""
# pylint: disable=unused-import
from __future__ import absolute_import
import ctypes
import sys
from .. import _api_internal
from .object_generic import ObjectGeneric, convert_to_object
from .base import _LIB, check_call, c_str, py_str, _FFI_MODE
IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError
try:
# pylint: disable=wrong-import-position
if _FFI_MODE == "ctypes":
raise ImportError()
if sys.version_info >= (3, 0):
from ._cy3.core import _register_object, ObjectBase as _ObjectBase
else:
from ._cy2.core import _register_object, ObjectBase as _ObjectBase
except IMPORT_EXCEPT:
# pylint: disable=wrong-import-position
from ._ctypes.object import _register_object, ObjectBase as _ObjectBase
def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)
class ObjectBase(_ObjectBase):
"""ObjectBase is the base class of all DGL CAPI object."""
def __dir__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.DGLObjectListAttrNames(
self.handle, ctypes.byref(size), ctypes.byref(plist)))
names = []
for i in range(size.value):
names.append(py_str(plist[i]))
return names
def __hash__(self):
return _api_internal._raw_ptr(self)
def __eq__(self, other):
return self.same_as(other)
def __ne__(self, other):
return not self.__eq__(other)
def __reduce__(self):
cls = type(self)
return (_new_object, (cls, ), self.__getstate__())
def __getstate__(self):
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
# to json. However, this is not true in DGL because DGL Object is meant
# for runtime API, so it could contain binary data such as NDArray.
# If this feature is required, please raise a RFC to DGL issue.
raise RuntimeError("__getstate__ is not supported for object type")
def __setstate__(self, state):
# pylint: disable=assigning-non-slot
# TODO(minjie): TVM assumes that a Node (Object in DGL) can be serialized
# to json. However, this is not true in DGL because DGL Object is meant
# for runtime API, so it could contain binary data such as NDArray.
# If this feature is required, please raise a RFC to DGL issue.
raise RuntimeError("__setstate__ is not supported for object type")
def same_as(self, other):
"""check object identity equality"""
if not isinstance(other, ObjectBase):
return False
return self.__hash__() == other.__hash__()
def register_object(type_key=None):
"""Decorator used to register object type
Examples
--------
>>> @register_object
>>> class MyObject:
>>> ... pass
Parameters
----------
type_key : str or cls
The type key of the object
"""
object_name = type_key if isinstance(type_key, str) else type_key.__name__
def register(cls):
"""internal register function"""
tindex = ctypes.c_int()
ret = _LIB.DGLObjectTypeKey2Index(c_str(object_name), ctypes.byref(tindex))
if ret == 0:
_register_object(tindex.value, cls)
return cls
if isinstance(type_key, str):
return register
return register(type_key)
"""Common implementation of Object generic related logic"""
# pylint: disable=unused-import
from __future__ import absolute_import
from numbers import Number, Integral
from .. import _api_internal
from .base import string_types
# Object base class
_CLASS_OBJECT_BASE = None
def _set_class_object_base(cls):
global _CLASS_OBJECT_BASE
_CLASS_OBJECT_BASE = cls
class ObjectGeneric(object):
"""Base class for all classes that can be converted to object."""
def asobject(self):
"""Convert value to object"""
raise NotImplementedError()
def convert_to_object(value):
"""Convert a python value to corresponding object type.
Parameters
----------
value : str
The value to be inspected.
Returns
-------
object : Object
The corresponding object value.
"""
if isinstance(value, _CLASS_OBJECT_BASE):
return value
if isinstance(value, (list, tuple)):
value = [convert_to_object(x) for x in value]
return _api_internal._List(*value)
if isinstance(value, dict):
vlist = []
for item in value.items():
if (not isinstance(item[0], _CLASS_OBJECT_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_object(item[1]))
return _api_internal._Map(*vlist)
if isinstance(value, ObjectGeneric):
return value.asobject()
return _api_internal._Value(value)
...@@ -20,7 +20,7 @@ class TypeCode(object): ...@@ -20,7 +20,7 @@ class TypeCode(object):
DGL_TYPE = 5 DGL_TYPE = 5
DGL_CONTEXT = 6 DGL_CONTEXT = 6
ARRAY_HANDLE = 7 ARRAY_HANDLE = 7
NODE_HANDLE = 8 OBJECT_HANDLE = 8
MODULE_HANDLE = 9 MODULE_HANDLE = 9
FUNC_HANDLE = 10 FUNC_HANDLE = 10
STR = 11 STR = 11
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment