Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
......@@ -32,8 +32,7 @@ DGL_DLL int DGLObjectFree(ObjectHandle handle);
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
int* out_index);
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key, int* out_index);
/*!
* \brief Get runtime type index of the object.
......@@ -41,8 +40,7 @@ DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index);
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index);
/*!
* \brief get attributes given key
......@@ -54,11 +52,9 @@ DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* out_value,
int* out_type_code,
int* out_success);
DGL_DLL int DGLObjectGetAttr(
ObjectHandle handle, const char* key, DGLValue* out_value,
int* out_type_code, int* out_success);
/*!
* \brief get attributes names in the object.
......@@ -67,9 +63,8 @@ DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array);
DGL_DLL int DGLObjectListAttrNames(
ObjectHandle handle, int* out_size, const char*** out_array);
#ifdef __cplusplus
} // DGL_EXTERN_C
#endif
......
......@@ -38,8 +38,8 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stddef.h>
#include <stdint.h>
/*! \brief type of array index. */
typedef int64_t dgl_index_t;
......@@ -60,7 +60,8 @@ typedef enum {
} DGLDeviceType;
/*!
* \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python.
* \brief The object type code is used in DGL FFI to indicate the types of
* objects passed between C and Python.
*/
typedef enum {
kInt = 0U,
......@@ -105,9 +106,9 @@ typedef enum {
} DGLDataTypeCode;
/*!
* \brief The data type the tensor can hold. The data type is assumed to follow the
* native endian-ness. An explicit error message should be raised when attempting to
* export an array with non-native endianness
* \brief The data type the tensor can hold. The data type is assumed to follow
* the native endian-ness. An explicit error message should be raised when
* attempting to export an array with non-native endianness
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
......@@ -149,12 +150,12 @@ typedef struct {
typedef struct {
/*!
* \brief The data pointer points to the allocated data.
*
*
* Depending on the device context, it can be a CPU pointer, or a CUDA
* device pointer or acl_mem handle in OpenCL.
* This pointer is always aligned to 256 bytes as in CUDA. Use the
* `byte_offset` field to mark the beginning of the actual data (if the address
* is not 256 byte aligned).
* `byte_offset` field to mark the beginning of the actual data (if the
* address is not 256 byte aligned).
*
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement
......@@ -247,7 +248,7 @@ DGL_DLL void DGLAPISetLastError(const char* msg);
* this function is threadsafe and can be called by different thread
* \return error info
*/
DGL_DLL const char *DGLGetLastError(void);
DGL_DLL const char* DGLGetLastError(void);
/*!
* \brief Load module from file.
* \param file_name The file name to load the module from.
......@@ -258,9 +259,8 @@ DGL_DLL const char *DGLGetLastError(void);
* \note The resulting module do not contain import relation.
* It can be reconstructed by DGLModImport.
*/
DGL_DLL int DGLModLoadFromFile(const char* file_name,
const char* format,
DGLModuleHandle* out);
DGL_DLL int DGLModLoadFromFile(
const char* file_name, const char* format, DGLModuleHandle* out);
/*!
* \brief Add dep to mod's dependency.
......@@ -270,8 +270,7 @@ DGL_DLL int DGLModLoadFromFile(const char* file_name,
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLModImport(DGLModuleHandle mod,
DGLModuleHandle dep);
DGL_DLL int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep);
/*!
* \brief Get function from the module.
......@@ -281,10 +280,9 @@ DGL_DLL int DGLModImport(DGLModuleHandle mod,
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name,
int query_imports,
DGLFunctionHandle *out);
DGL_DLL int DGLModGetFunction(
DGLModuleHandle mod, const char* func_name, int query_imports,
DGLFunctionHandle* out);
/*!
* \brief Free front-end extension type resource.
......@@ -334,12 +332,9 @@ DGL_DLL int DGLFuncFree(DGLFunctionHandle func);
* The front-end need to call free function (e.g. DGLFuncFree)
* to free these handles.
*/
DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
DGLValue* arg_values,
int* type_codes,
int num_args,
DGLValue* ret_val,
int* ret_type_code);
DGL_DLL int DGLFuncCall(
DGLFunctionHandle func, DGLValue* arg_values, int* type_codes, int num_args,
DGLValue* ret_val, int* ret_type_code);
/*!
* \brief Set the return value of DGLPackedCFunc.
......@@ -352,10 +347,8 @@ DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
* \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported.
*/
DGL_DLL int DGLCFuncSetReturn(DGLRetValueHandle ret,
DGLValue* value,
int* type_code,
int num_ret);
DGL_DLL int DGLCFuncSetReturn(
DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret);
/*!
* \brief Inplace translate callback argument value to return value.
......@@ -377,14 +370,12 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);
* \param num_args Number of arguments.
* \param ret The return value handle.
* \param resource_handle The handle additional resouce handle from fron-end.
* \return 0 if success, -1 if failure happens, set error via DGLAPISetLastError.
* \return 0 if success, -1 if failure happens, set error via
* DGLAPISetLastError.
* \sa DGLCFuncSetReturn
*/
typedef int (*DGLPackedCFunc)(
DGLValue* args,
int* type_codes,
int num_args,
DGLRetValueHandle ret,
DGLValue* args, int* type_codes, int num_args, DGLRetValueHandle ret,
void* resource_handle);
/*!
......@@ -407,18 +398,19 @@ typedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle);
/*!
* \brief Wrap a DGLPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by DGL API, until the function is no longer used.
* The resource_handle will be managed by DGL API, until the function is no
* longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param fin The finalizer on resource handle when the FunctionHandle get
* freed, can be NULL.
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens.
*/
DGL_DLL int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle,
DGLPackedCFuncFinalizer fin,
DGLFunctionHandle *out);
DGL_DLL int DGLFuncCreateFromCFunc(
DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLFunctionHandle* out);
/*!
* \brief Register the function to runtime's global table.
......@@ -449,8 +441,7 @@ DGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out);
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLFuncListGlobalNames(int* out_size,
const char*** out_array);
DGL_DLL int DGLFuncListGlobalNames(int* out_size, const char*** out_array);
// Array related apis for quick proptyping
/*!
......@@ -467,14 +458,9 @@ DGL_DLL int DGLFuncListGlobalNames(int* out_size,
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out);
DGL_DLL int DGLArrayAlloc(
const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out);
/*!
* \brief Allocate a nd-array's with shared memory,
......@@ -490,14 +476,9 @@ DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
bool is_create,
DGLArrayHandle* out);
int DGLArrayAllocSharedMem(
const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out);
/*!
* \brief Free the DGL Array.
......@@ -513,9 +494,8 @@ DGL_DLL int DGLArrayFree(DGLArrayHandle handle);
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
void* data,
size_t nbytes);
DGL_DLL int DGLArrayCopyFromBytes(
DGLArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
......@@ -524,9 +504,8 @@ DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
void* data,
size_t nbytes);
DGL_DLL int DGLArrayCopyToBytes(
DGLArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
......@@ -534,8 +513,7 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
* \param to The target space.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
DGLArrayHandle to);
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to);
/*!
* \brief Create a new runtime stream.
......@@ -545,7 +523,8 @@ DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
* \param out The new stream handle
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out);
DGL_DLL int DGLStreamCreate(
int device_type, int device_id, DGLStreamHandle* out);
/*!
* \brief Free a created stream handle.
......@@ -555,7 +534,8 @@ DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream);
DGL_DLL int DGLStreamFree(
int device_type, int device_id, DGLStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
......@@ -568,7 +548,8 @@ DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle);
DGL_DLL int DGLSetStream(
int device_type, int device_id, DGLStreamHandle handle);
/*!
* \brief Get the runtime stream of current thread.
......@@ -578,7 +559,8 @@ DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle)
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle);
DGL_DLL int DGLGetStream(
int device_type, int device_id, DGLStreamHandle* handle);
/*!
* \brief Wait until all computations on stream completes.
......@@ -588,7 +570,8 @@ DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream);
DGL_DLL int DGLSynchronize(
int device_type, int device_id, DGLStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
......@@ -599,16 +582,14 @@ DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle strea
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamStreamSynchronize(int device_type,
int device_id,
DGLStreamHandle src,
DGLStreamHandle dst);
DGL_DLL int DGLStreamStreamSynchronize(
int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst);
/*!
* \brief Load tensor adapter.
* \return 0 when success, -1 when failure happens.
*/
DGL_DLL int DGLLoadTensorAdapter(const char *path);
DGL_DLL int DGLLoadTensorAdapter(const char* path);
/*!
* \brief Pin host memory.
......@@ -628,17 +609,18 @@ int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);
/*!
* \brief Bug report macro.
*
* This serves as a sanity check on system side to make sure the code is correct by
* checking whether a condition always holds for complex reasons. Failing the
* condition signifies a system bug instead of users giving invalid inputs or using
* the functionality incorrectly.
* This serves as a sanity check on system side to make sure the code is correct
* by checking whether a condition always holds for complex reasons. Failing
* the condition signifies a system bug instead of users giving invalid inputs
* or using the functionality incorrectly.
*
* Hints the user to file a bug report if the condition fails.
*/
#define BUG_IF_FAIL(cond) \
CHECK(cond) << "A bug has been occurred. " \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
#define BUG_IF_FAIL(cond) \
CHECK(cond) \
<< "A bug has been occurred. " \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
#ifdef __cplusplus
} // DGL_EXTERN_C
......
......@@ -4,7 +4,6 @@
* \brief DGL runtime config
*/
#ifndef DGL_RUNTIME_CONFIG_H_
#define DGL_RUNTIME_CONFIG_H_
......@@ -23,7 +22,7 @@ class Config {
bool IsLibxsmmAvailable() const;
private:
Config() = default;
Config() = default;
bool libxsmm_ = true;
};
......
......@@ -6,11 +6,12 @@
#ifndef DGL_RUNTIME_CONTAINER_H_
#define DGL_RUNTIME_CONTAINER_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "object.h"
#include "packed_func.h"
......@@ -44,7 +45,7 @@ inline std::shared_ptr<ValueObject> MakeValue(T&& val) {
class Value : public ObjectRef {
public:
Value() {}
explicit Value(std::shared_ptr<Object> o): ObjectRef(o) {}
explicit Value(std::shared_ptr<Object> o) : ObjectRef(o) {}
const ValueObject* operator->() const {
return static_cast<const ValueObject*>(obj_.get());
......@@ -60,7 +61,7 @@ class ListObject : public Object {
std::vector<std::shared_ptr<Object> > data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to list have no effect.
// Visitor to list have no effect.
}
static constexpr const char* _type_key = "List";
......@@ -71,7 +72,7 @@ class ListObject : public Object {
class MapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
// Visitor to map have no effect.
}
// hash function
struct Hash {
......@@ -90,9 +91,7 @@ class MapObject : public Object {
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::shared_ptr<Object>,
std::shared_ptr<Object>,
Hash, Equal>;
std::shared_ptr<Object>, std::shared_ptr<Object>, Hash, Equal>;
/*! \brief the data content */
ContainerType data;
......@@ -101,17 +100,15 @@ class MapObject : public Object {
DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object);
};
/*! \brief specialized map obj with string as key */
class StrMapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::string,
std::shared_ptr<Object> >;
using ContainerType =
std::unordered_map<std::string, std::shared_ptr<Object> >;
/*! \brief the data content */
ContainerType data;
......@@ -125,8 +122,7 @@ class StrMapObject : public Object {
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template<typename Converter,
typename TIter>
template <typename Converter, typename TIter>
class IterAdapter {
public:
explicit IterAdapter(TIter iter) : iter_(iter) {}
......@@ -144,9 +140,7 @@ class IterAdapter {
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
inline bool operator!=(IterAdapter other) const {
return !(*this == other);
}
inline bool operator!=(IterAdapter other) const { return !(*this == other); }
inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_);
}
......@@ -175,7 +169,7 @@ class IterAdapter {
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
*
* Example:
*
* <code>
......@@ -183,31 +177,32 @@ class IterAdapter {
* // List<NDArray> list2; // fails
* List<Value> list; // works
* list.push_back(Value(MakeValue(1))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code>
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
template <
typename T,
typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class List : public ObjectRef {
public:
/*!
* \brief default constructor
*/
List() {
obj_ = std::make_shared<ListObject>();
}
List() { obj_ = std::make_shared<ListObject>(); }
/*!
* \brief move constructor
* \param other source
*/
List(List<T> && other) { // NOLINT(*)
List(List<T>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
List(const List<T> &other) : ObjectRef(other.obj_) { // NOLINT(*)
List(const List<T>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
......@@ -220,7 +215,7 @@ class List : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
List(IterType begin, IterType end) {
assign(begin, end);
}
......@@ -228,20 +223,19 @@ class List : public ObjectRef {
* \brief constructor from initializer list
* \param init The initalizer list
*/
List(std::initializer_list<T> init) { // NOLINT(*)
List(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
List(const std::vector<T>& init) { // NOLINT(*)
List(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
* \brief Constructs a container with n elements. Each element is a copy of
* val \param n The size of the container \param val The init value
*/
explicit List(size_t n, const T& val) {
auto tmp_obj = std::make_shared<ListObject>();
......@@ -255,7 +249,7 @@ class List : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(List<T> && other) {
List<T>& operator=(List<T>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
......@@ -264,7 +258,7 @@ class List : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(const List<T> & other) {
List<T>& operator=(const List<T>& other) {
obj_ = other.obj_;
return *this;
}
......@@ -274,7 +268,7 @@ class List : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<ListObject>();
for (IterType it = begin; it != end; ++it) {
......@@ -304,7 +298,7 @@ class List : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline ListObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<ListObject>(
*static_cast<const ListObject*>(obj_.get()));
}
......@@ -328,9 +322,7 @@ class List : public ObjectRef {
n->data[i] = value.obj_;
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
/*! \brief Copy the content to a vector */
inline std::vector<T> ToVector() const {
return std::vector<T>(begin(), end());
......@@ -340,16 +332,14 @@ class List : public ObjectRef {
struct Ptr2ObjectRef {
using ResultType = T;
static inline T convert(const std::shared_ptr<Object>& n) {
return T(n);
}
static inline T convert(const std::shared_ptr<Object>& n) { return T(n); }
};
using iterator = IterAdapter<Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_iterator>;
using iterator = IterAdapter<
Ptr2ObjectRef, std::vector<std::shared_ptr<Object> >::const_iterator>;
using reverse_iterator = IterAdapter<
Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -361,11 +351,13 @@ class List : public ObjectRef {
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rbegin());
return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rend());
return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rend());
}
};
......@@ -390,7 +382,7 @@ class List : public ObjectRef {
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
*
* Example:
*
* <code>
......@@ -398,35 +390,35 @@ class List : public ObjectRef {
* // Map<std::string, NDArray> map2; // fails
* Map<std::string, Value> map; // works
* map.Set("key1", Value(MakeValue(1))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code>
*/
template<typename K,
typename V,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
template <
typename K, typename V,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value>::type,
typename =
typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Map() {
obj_ = std::make_shared<MapObject>();
}
Map() { obj_ = std::make_shared<MapObject>(); }
/*!
* \brief move constructor
* \param other source
*/
Map(Map<K, V> && other) { // NOLINT(*)
Map(Map<K, V>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
Map(const Map<K, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
......@@ -439,7 +431,7 @@ class Map : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
......@@ -447,15 +439,15 @@ class Map : public ObjectRef {
* \brief constructor from initializer list
* \param init The initalizer list
*/
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
template<typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
......@@ -463,7 +455,7 @@ class Map : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(Map<K, V> && other) {
Map<K, V>& operator=(Map<K, V>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
......@@ -472,7 +464,7 @@ class Map : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(const Map<K, V> & other) {
Map<K, V>& operator=(const Map<K, V>& other) {
obj_ = other.obj_;
return *this;
}
......@@ -482,12 +474,11 @@ class Map : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::shared_ptr<MapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.obj_,
i->second.obj_));
n->data.emplace(std::make_pair(i->first.obj_, i->second.obj_));
}
obj_ = std::move(n);
}
......@@ -526,7 +517,7 @@ class Map : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline MapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
......@@ -543,23 +534,20 @@ class Map : public ObjectRef {
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
/*! \brief specify container obj */
using ContainerType = MapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair<
std::shared_ptr<Object>,
std::shared_ptr<Object> >& n) {
static inline ResultType convert(
const std::pair<std::shared_ptr<Object>, std::shared_ptr<Object> >& n) {
return std::make_pair(K(n.first), V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
using iterator =
IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -571,50 +559,49 @@ class Map : public ObjectRef {
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
return iterator(
static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
}
};
// specialize of string map
template<typename V, typename T1, typename T2>
template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() {
obj_ = std::make_shared<StrMapObject>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
Map() { obj_ = std::make_shared<StrMapObject>(); }
Map(Map<std::string, V>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
Map(const Map<std::string, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
Map(const Map<std::string, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}
template<typename IterType>
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
template <typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>&
init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V> && other) {
Map<std::string, V>& operator=(Map<std::string, V>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
Map<std::string, V>& operator=(const Map<std::string, V>& other) {
obj_ = other.obj_;
return *this;
}
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<StrMapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first,
i->second.obj_));
n->data.emplace(std::make_pair(i->first, i->second.obj_));
}
obj_ = std::move(n);
}
......@@ -633,7 +620,7 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
return static_cast<const StrMapObject*>(obj_.get())->data.count(key);
}
inline StrMapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
......@@ -643,22 +630,19 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
StrMapObject* n = this->CopyOnWrite();
n->data[key] = value.obj_;
}
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
using ContainerType = StrMapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<
std::string,
std::shared_ptr<Object> >& n) {
static inline ResultType convert(
const std::pair<std::string, std::shared_ptr<Object> >& n) {
return std::make_pair(n.first, V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
using iterator =
IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -670,7 +654,8 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.find(key));
return iterator(
static_cast<const StrMapObject*>(obj_.get())->data.find(key));
}
};
......
......@@ -7,8 +7,9 @@
#define DGL_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
#include "packed_func.h"
namespace dgl {
namespace runtime {
......@@ -30,7 +31,8 @@ enum DeviceAttrKind : int {
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
/*! \brief Number of bytes each allocation must align to in temporary allocation
*/
constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
......@@ -47,9 +49,7 @@ class DeviceAPI {
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
virtual bool IsAvailable() { return true; }
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
......@@ -62,7 +62,8 @@ class DeviceAPI {
* \param rv The return value.
* \sa DeviceAttrKind
*/
virtual void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
virtual void GetAttr(
DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
......@@ -72,10 +73,9 @@ class DeviceAPI {
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer.
*/
virtual void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
DGLDataType type_hint) = 0;
virtual void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
......@@ -94,15 +94,11 @@ class DeviceAPI {
* \param type_hint The type of elements, only neded by certain backends.
* can be useful for cross device endian converison.
*/
virtual void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) = 0;
/*!
virtual void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t num_bytes, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
......@@ -145,31 +141,28 @@ class DeviceAPI {
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst);
DGL_DLL virtual void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst);
/*!
* \brief Pin host memory using cudaHostRegister().
*
* \param ptr The host memory pointer to be pinned.
* \param nbytes The size to be pinned.
*/
* \param nbytes The size to be pinned.
*/
DGL_DLL virtual void PinData(void* ptr, size_t nbytes);
/*!
* \brief Unpin host memory using cudaHostUnregister().
*
* \param ptr The host memory pointer to be unpinned.
*/
* \param ptr The host memory pointer to be unpinned.
*/
DGL_DLL virtual void UnpinData(void* ptr);
/*!
* \brief Check whether the memory is in pinned memory.
*/
DGL_DLL virtual bool IsPinned(const void* ptr) {
return false;
}
DGL_DLL virtual bool IsPinned(const void* ptr) { return false; }
/*!
* \brief Allocate temporal workspace for backend execution.
......@@ -180,16 +173,16 @@ class DeviceAPI {
* - Only a few allocation will happen, and space will be released after use.
* - The release order is usually in reverse order of allocate (stack style).
* - Repeative pattern of same allocations over different runs.
* - Workspace should not overlap between different threads(i.e. be threadlocal)
* - Workspace should not overlap between different threads(i.e. be
* threadlocal)
*
* \param ctx The context of allocation.
* \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
size_t nbytes,
DGLDataType type_hint = {});
DGL_DLL virtual void* AllocWorkspace(
DGLContext ctx, size_t nbytes, DGLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
......@@ -206,14 +199,14 @@ class DeviceAPI {
*/
DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);
/*!
* \brief Get device API based on context.
* \param dev_type The device type
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false);
DGL_DLL static DeviceAPI* Get(
DGLDeviceType dev_type, bool allow_missing = false);
};
/*! \brief The device type bigger than this is RPC device */
......
......@@ -31,10 +31,10 @@ struct DLPackConvert {
/*!
* \brief Deleter for NDArray converted from DLPack.
*
* This is used from data which is passed from external DLPack(DLManagedTensor)
* that are not allocated inside of DGL.
* This enables us to create NDArray from memory allocated by other
* frameworks that are DLPack compatible
* This is used from data which is passed from external
* DLPack(DLManagedTensor) that are not allocated inside of DGL. This enables
* us to create NDArray from memory allocated by other frameworks that are
* DLPack compatible
*/
static void DLPackDeleter(NDArray::Container* ptr);
......@@ -43,7 +43,7 @@ struct DLPackConvert {
* \param from The DGL NDArray.
* \return A DLPack tensor.
*/
static DLManagedTensor* ToDLPack(const NDArray &from);
static DLManagedTensor* ToDLPack(const NDArray& from);
};
} // namespace runtime
......@@ -66,8 +66,7 @@ DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
......@@ -76,8 +75,8 @@ DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
DGL_DLL int DGLArrayToDLPack(
DGLArrayHandle from, DLManagedTensor** out, int alignment = 0);
#ifdef __cplusplus
} // DGL_EXTERN_C
......
......@@ -9,10 +9,12 @@
#define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <vector>
#include "c_runtime_api.h"
namespace dgl {
......@@ -29,8 +31,7 @@ class Module {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
explicit Module(std::shared_ptr<ModuleNode> n) : node_(n) {}
/*!
* \brief Get packed function from current module by name.
*
......@@ -40,7 +41,8 @@ class Module {
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
inline PackedFunc GetFunction(
const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
......@@ -61,8 +63,8 @@ class Module {
* \note This function won't load the import relationship.
* Re-create import relationship by calling Import.
*/
DGL_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");
DGL_DLL static Module LoadFromFile(
const std::string& file_name, const std::string& format = "");
private:
std::shared_ptr<ModuleNode> node_;
......@@ -103,8 +105,8 @@ class ModuleNode {
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format);
virtual void SaveToFile(
const std::string& file_name, const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
......@@ -128,9 +130,7 @@ class ModuleNode {
*/
DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */
const std::vector<Module>& imports() const {
return imports_;
}
const std::vector<Module>& imports() const { return imports_; }
protected:
friend class Module;
......@@ -139,8 +139,7 @@ class ModuleNode {
private:
/*! \brief Cache used by GetImport */
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
std::unordered_map<std::string, std::unique_ptr<PackedFunc> > import_cache_;
};
/*! \brief namespace for constant symbols */
......@@ -155,20 +154,19 @@ constexpr const char* dgl_dev_mblob_nbytes = "__dgl_dev_mblob_nbytes";
constexpr const char* dgl_set_device = "__dgl_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* dgl_global_barrier_state = "__dgl_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */
constexpr const char* dgl_prepare_global_barrier = "__dgl_prepare_global_barrier";
/*!
* \brief Prepare the global barrier before kernels that uses global barrier.
*/
constexpr const char* dgl_prepare_global_barrier =
"__dgl_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* dgl_module_main = "__dgl_main__";
} // namespace symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
return node_.get();
}
inline ModuleNode* Module::operator->() { return node_.get(); }
inline const ModuleNode* Module::operator->() const {
return node_.get();
}
inline const ModuleNode* Module::operator->() const { return node_.get(); }
} // namespace runtime
} // namespace dgl
......
......@@ -7,10 +7,11 @@
#define DGL_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
namespace dgl {
namespace runtime {
......@@ -26,7 +27,7 @@ class NDArray;
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
......@@ -35,14 +36,16 @@ class AttrVisitor {
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
template <
typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
static_assert(
std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
//! \endcond
};
/*!
......@@ -87,20 +90,19 @@ class Object {
/*!
* \return whether the type is derived from
*/
template<typename T>
template <typename T>
inline bool derived_from() const;
/*!
* \return whether the object is of type T
* \tparam The type to be checked.
*/
template<typename T>
template <typename T>
inline bool is_type() const;
// object ref can see this
friend class ObjectRef;
static constexpr const char* _type_key = "Object";
};
/*! \brief base class of all reference object */
class ObjectRef {
public:
......@@ -109,7 +111,8 @@ class ObjectRef {
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
* Compare with the two are referencing to the same object (compare by
* address).
*
* \param other Another object ref.
* \return the compare result.
......@@ -119,7 +122,8 @@ class ObjectRef {
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
* Compare with the two are referencing to the same object (compare by
* address).
*
* \param other Another object ref.
* \return the compare result.
......@@ -161,8 +165,8 @@ class ObjectRef {
* }
* \tparam T the target type, must be subtype of Object
*/
template<typename T>
inline const T *as() const;
template <typename T>
inline const T* as() const;
/*! \brief default constructor */
ObjectRef() = default;
......@@ -178,11 +182,11 @@ class ObjectRef {
* This is macro should be used in abstract base class definition
* because it does not define type_key and type_index.
*/
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
......@@ -206,69 +210,61 @@ class ObjectRef {
* DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass);
* };
*/
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { return TypeName::_type_key; } \
uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*! \brief Macro to generate common object reference class method definition */
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj): BaseTypeName(obj) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(obj_.get()); \
} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(obj_.get()); \
} \
std::shared_ptr<ObjectName> sptr() const { \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
} \
operator bool() const { return this->defined(); } \
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj) \
: BaseTypeName(obj) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(obj_.get()); \
} \
ObjectName* operator->() { return static_cast<ObjectName*>(obj_.get()); } \
std::shared_ptr<ObjectName> sptr() const { \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = ObjectName
/*! \brief Macro to generate object reference class definition */
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \
public: \
DGL_DEFINE_OBJECT_REF_METHODS(TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \
public: \
DGL_DEFINE_OBJECT_REF_METHODS( \
TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
}
// implementations of inline functions after this
template<typename T>
template <typename T>
inline bool Object::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
template <typename T>
inline bool Object::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline const Object* ObjectRef::get() const {
return obj_.get();
}
inline const Object* ObjectRef::get() const { return obj_.get(); }
inline const Object* ObjectRef::operator->() const {
return obj_.get();
}
inline const Object* ObjectRef::operator->() const { return obj_.get(); }
inline bool ObjectRef::defined() const {
return obj_.get() != nullptr;
}
inline bool ObjectRef::defined() const { return obj_.get() != nullptr; }
inline bool ObjectRef::operator==(const ObjectRef& other) const {
return obj_.get() == other.obj_.get();
......@@ -295,7 +291,7 @@ inline uint32_t ObjectRef::type_index() const {
return get()->type_index();
}
template<typename T>
template <typename T>
inline const T* ObjectRef::as() const {
const Object* ptr = get();
if (ptr && ptr->is_type<T>()) {
......@@ -306,9 +302,7 @@ inline const T* ObjectRef::as() const {
/*! \brief The hash function for nodes */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
return a.hash();
}
size_t operator()(const ObjectRef& a) const { return a.hash(); }
};
/*! \brief The equal comparator for nodes */
......
......@@ -7,14 +7,16 @@
#define DGL_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include <limits>
#include <memory>
#include <utility>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
......@@ -67,7 +69,7 @@ class PackedFunc {
* }
* \endcode
*/
using FType = std::function<void (DGLArgs args, DGLRetValue* rv)>;
using FType = std::function<void(DGLArgs args, DGLRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
......@@ -89,8 +91,8 @@ class PackedFunc {
* }
* \endcode
*/
template<typename... Args>
inline DGLRetValue operator()(Args&& ...args) const;
template <typename... Args>
inline DGLRetValue operator()(Args&&... args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
......@@ -100,13 +102,9 @@ class PackedFunc {
/*! \return the internal body function */
inline FType body() const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const {
return body_ == nullptr;
}
bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const {
return body_ != nullptr;
}
bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
private:
/*! \brief internal container of packed function */
......@@ -114,9 +112,10 @@ class PackedFunc {
};
/*!
* \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args..)>"
* \brief Please refer to \ref TypedPackedFuncAnchor
* "TypedPackedFunc<R(Args..)>"
*/
template<typename FType>
template <typename FType>
class TypedPackedFunc;
/*!
......@@ -151,7 +150,7 @@ class TypedPackedFunc;
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
template<typename R, typename ...Args>
template <typename R, typename... Args>
class TypedPackedFunc<R(Args...)> {
public:
/*! \brief short hand for this function type */
......@@ -191,11 +190,9 @@ class TypedPackedFunc<R(Args...)> {
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
template <
typename FLambda, typename = typename std::enable_if<std::is_convertible<
FLambda, std::function<R(Args...)> >::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) {
this->AssignTypedLambda(typed_lambda);
}
......@@ -215,11 +212,10 @@ class TypedPackedFunc<R(Args...)> {
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
template <
typename FLambda, typename = typename std::enable_if<std::is_convertible<
FLambda,
std::function<R(Args...)> >::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
return *this;
......@@ -238,20 +234,16 @@ class TypedPackedFunc<R(Args...)> {
* \param args The arguments
* \returns The return value.
*/
inline R operator()(Args ...args) const;
inline R operator()(Args... args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
*/
operator PackedFunc() const {
return packed();
}
operator PackedFunc() const { return packed(); }
/*!
* \return reference the internal PackedFunc
*/
const PackedFunc& packed() const {
return packed_;
}
const PackedFunc& packed() const { return packed_; }
private:
friend class DGLRetValue;
......@@ -264,7 +256,7 @@ class TypedPackedFunc<R(Args...)> {
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template<typename FLambda>
template <typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
};
......@@ -280,12 +272,8 @@ class DGLArgs {
* \param type_codes The argument type codes
* \param num_args number of arguments.
*/
DGLArgs(const DGLValue* values,
const int* type_codes,
int num_args)
: values(values),
type_codes(type_codes),
num_args(num_args) { }
DGLArgs(const DGLValue* values, const int* type_codes, int num_args)
: values(values), type_codes(type_codes), num_args(num_args) {}
/*! \return size of the arguments */
inline int size() const;
/*!
......@@ -307,7 +295,7 @@ class DGLArgs {
*
* \tparam T the typename
*/
template<typename T>
template <typename T>
struct extension_class_info {
static const int code = 0;
};
......@@ -337,7 +325,8 @@ class ExtTypeVTable {
private:
// Internal registration function.
DGL_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
DGL_DLL static ExtTypeVTable* RegisterInternal(
int type_code, const ExtTypeVTable& vt);
};
/*!
......@@ -366,8 +355,7 @@ class DGLPODValue_ {
}
operator int() const {
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max());
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
......@@ -381,14 +369,12 @@ class DGLPODValue_ {
return value_.v_handle;
}
operator DGLArray*() const {
if (type_code_ == kArrayHandle ||
type_code_ == kNDArrayContainer) {
if (type_code_ == kArrayHandle || type_code_ == kNDArrayContainer) {
return static_cast<DGLArray*>(value_.v_handle);
} else {
if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected "
<< "DGLArray* or NDArray but get "
<< TypeCode2Str(type_code_);
<< "DGLArray* or NDArray but get " << TypeCode2Str(type_code_);
return nullptr;
}
}
......@@ -401,20 +387,18 @@ class DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kDGLContext);
return value_.v_ctx;
}
template<typename TExtension>
template <typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
int type_code() const {
return type_code_;
}
int type_code() const { return type_code_; }
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
* \return The pointer type.
*/
template<typename T>
template <typename T>
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
......@@ -447,9 +431,7 @@ class DGLArgValue : public DGLPODValue_ {
* \param value of the function
* \param type_code The type code.
*/
DGLArgValue(DGLValue value, int type_code)
: DGLPODValue_(value, type_code) {
}
DGLArgValue(DGLValue value, int type_code) : DGLPODValue_(value, type_code) {}
// reuse converter from parent
using DGLPODValue_::operator double;
using DGLPODValue_::operator int64_t;
......@@ -485,7 +467,7 @@ class DGLArgValue : public DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
......@@ -493,24 +475,22 @@ class DGLArgValue : public DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const DGLValue& value() const {
return value_;
}
const DGLValue& value() const { return value_; }
// Deferred extension handler.
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef AsObjectRef() const;
// Convert this value to arbitrary class type
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
template <
typename T,
typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
// Return true if the value is of TObjectRef type
template<typename TObjectRef,
typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
template <
typename TObjectRef, typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
inline bool IsObjectType() const;
// get internal node ptr, if it is node
......@@ -539,9 +519,7 @@ class DGLRetValue : public DGLPODValue_ {
other.type_code_ = kNull;
}
/*! \brief destructor */
~DGLRetValue() {
this->Clear();
}
~DGLRetValue() { this->Clear(); }
// reuse converter from parent
using DGLPODValue_::operator double;
using DGLPODValue_::operator int64_t;
......@@ -553,9 +531,7 @@ class DGLRetValue : public DGLPODValue_ {
using DGLPODValue_::operator DGLContext;
using DGLPODValue_::operator NDArray;
// Disable copy and assign from another value, but allow move.
DGLRetValue(const DGLRetValue& other) {
this->Assign(other);
}
DGLRetValue(const DGLRetValue& other) { this->Assign(other); }
// conversion operators
operator std::string() const {
if (type_code_ == kDGLDataType) {
......@@ -578,7 +554,7 @@ class DGLRetValue : public DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
......@@ -653,7 +629,7 @@ class DGLRetValue : public DGLPODValue_ {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
template <typename FType>
DGLRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
......@@ -669,12 +645,11 @@ class DGLRetValue : public DGLPODValue_ {
this->Assign(other);
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
template <
typename T, typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
DGLRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
this->SwitchToClass<T>(extension_class_info<T>::code, other);
return *this;
}
/*!
......@@ -686,8 +661,7 @@ class DGLRetValue : public DGLPODValue_ {
* \param ret_value The return value.
* \param ret_type_code The return type code.
*/
void MoveToCHost(DGLValue* ret_value,
int* ret_type_code) {
void MoveToCHost(DGLValue* ret_value, int* ret_type_code) {
// cannot move str; need specially handle.
CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_;
......@@ -696,24 +670,24 @@ class DGLRetValue : public DGLPODValue_ {
}
/*! \return The value field, if the data is POD */
const DGLValue& value() const {
CHECK(type_code_ != kObjectHandle &&
type_code_ != kFuncHandle &&
type_code_ != kModuleHandle &&
type_code_ != kStr) << "DGLRetValue.value can only be used for POD data";
CHECK(
type_code_ != kObjectHandle && type_code_ != kFuncHandle &&
type_code_ != kModuleHandle && type_code_ != kStr)
<< "DGLRetValue.value can only be used for POD data";
return value_;
}
// ObjectRef related extenstions: in dgl/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
template <
typename T,
typename = typename std::enable_if<std::is_class<T>::value>::type>
inline operator T() const;
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef AsObjectRef() const;
inline DGLRetValue& operator=(const ObjectRef& other);
inline DGLRetValue& operator=(const std::shared_ptr<Object>& other);
private:
template<typename T>
template <typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr: {
......@@ -751,9 +725,8 @@ class DGLRetValue : public DGLPODValue_ {
#else
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
value_.v_handle = (*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
#endif
}
break;
......@@ -767,7 +740,7 @@ class DGLRetValue : public DGLPODValue_ {
type_code_ = type_code;
}
}
template<typename T>
template <typename T>
void SwitchToClass(int type_code, T v) {
if (type_code_ != type_code) {
this->Clear();
......@@ -780,10 +753,19 @@ class DGLRetValue : public DGLPODValue_ {
void Clear() {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: case kBytes: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kObjectHandle: delete ptr<std::shared_ptr<Object> >(); break;
case kStr:
case kBytes:
delete ptr<std::string>();
break;
case kFuncHandle:
delete ptr<PackedFunc>();
break;
case kModuleHandle:
delete ptr<Module>();
break;
case kObjectHandle:
delete ptr<std::shared_ptr<Object> >();
break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
......@@ -791,7 +773,7 @@ class DGLRetValue : public DGLPODValue_ {
}
if (type_code_ > kExtBegin) {
#if DGL_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
LOG(FATAL) << "Header only mode do not support ext type";
#else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
#endif
......@@ -802,49 +784,42 @@ class DGLRetValue : public DGLPODValue_ {
// implementation details
inline DGLArgValue DGLArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< " but request arg[" << i << "].";
CHECK_LT(i, num_args) << "not enough argument passed, " << num_args
<< " passed"
<< " but request arg[" << i << "].";
return DGLArgValue(values[i], type_codes[i]);
}
inline int DGLArgs::size() const {
return num_args;
}
inline int DGLArgs::size() const { return num_args; }
inline void PackedFunc::CallPacked(DGLArgs args, DGLRetValue* rv) const {
body_(args, rv);
}
inline PackedFunc::FType PackedFunc::body() const {
return body_;
}
inline PackedFunc::FType PackedFunc::body() const { return body_; }
// internal namespace
namespace detail {
template<bool stop, std::size_t I, typename F>
template <bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
template<typename T, typename ...Args>
template <typename T, typename... Args>
static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
f(I, std::forward<T>(value));
for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
::run(f, std::forward<Args>(args)...);
for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::run(
f, std::forward<Args>(args)...);
}
};
template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
template <std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
static void run(const F& f) {} // NOLINT(*)
};
template<typename F, typename ...Args>
template <typename F, typename... Args>
inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
for_each_dispatcher<sizeof...(Args) == 0, 0, F>::run(
f, std::forward<Args>(args)...);
}
} // namespace detail
......@@ -854,17 +829,16 @@ class DGLArgsSetter {
DGLArgsSetter(DGLValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {}
// setters for POD types
template<typename T,
typename = typename std::enable_if<
std::is_integral<T>::value>::type>
template <
typename T,
typename = typename std::enable_if<std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDGLInt;
}
void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
CHECK_LE(value, static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDGLInt;
}
void operator()(size_t i, double value) const {
......@@ -914,8 +888,9 @@ class DGLArgsSetter {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
template <typename FType>
void operator()(
size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
......@@ -937,9 +912,9 @@ class DGLArgsSetter {
}
}
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
template <
typename T, typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// ObjectRef related extenstions: in dgl/packed_func_ext.h
inline void operator()(size_t i, const ObjectRef& other) const; // NOLINT(*)
......@@ -951,156 +926,143 @@ class DGLArgsSetter {
int* type_codes_;
};
template<typename... Args>
inline DGLRetValue PackedFunc::operator()(Args&& ...args) const {
template <typename... Args>
inline DGLRetValue PackedFunc::operator()(Args&&... args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
DGLValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(DGLArgsSetter(values, type_codes),
std::forward<Args>(args)...);
detail::for_each(
DGLArgsSetter(values, type_codes), std::forward<Args>(args)...);
DGLRetValue rv;
body_(DGLArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
namespace detail {
template<typename R, int nleft, int index, typename F>
template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template<typename ...Args>
static void run(const F& f,
const DGLArgs& args_pack,
DGLRetValue* rv,
Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv,
std::forward<Args>(unpacked_args)...,
args_pack[index]);
template <typename... Args>
static void run(
const F& f, const DGLArgs& args_pack, DGLRetValue* rv,
Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F>::run(
f, args_pack, rv, std::forward<Args>(unpacked_args)...,
args_pack[index]);
}
};
template<typename R, int index, typename F>
template <typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const DGLArgs& args_pack,
DGLRetValue* rv,
Args&&... unpacked_args) {
template <typename... Args>
static void run(
const F& f, const DGLArgs& args_pack, DGLRetValue* rv,
Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...));
}
};
template<int index, typename F>
template <int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const DGLArgs& args_pack,
DGLRetValue* rv,
Args&&... unpacked_args) {
template <typename... Args>
static void run(
const F& f, const DGLArgs& args_pack, DGLRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};
template<typename R, int nargs, typename F>
template <typename R, int nargs, typename F>
inline void unpack_call(const F& f, const DGLArgs& args, DGLRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
template <typename R, typename... Args>
inline R call_packed(const PackedFunc& pf, Args&&... args) {
return R(pf(std::forward<Args>(args)...));
}
template<typename R>
template <typename R>
struct typed_packed_call_dispatcher {
template<typename ...Args>
static inline R run(const PackedFunc& pf, Args&& ...args) {
template <typename... Args>
static inline R run(const PackedFunc& pf, Args&&... args) {
return pf(std::forward<Args>(args)...);
}
};
template<>
template <>
struct typed_packed_call_dispatcher<void> {
template<typename ...Args>
static inline void run(const PackedFunc& pf, Args&& ...args) {
template <typename... Args>
static inline void run(const PackedFunc& pf, Args&&... args) {
pf(std::forward<Args>(args)...);
}
};
} // namespace detail
template<typename R, typename ...Args>
template <typename R, typename... Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
: packed_(packed) {}
: packed_(packed) {}
template<typename R, typename ...Args>
template<typename FType>
template <typename R, typename... Args>
template <typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const DGLArgs& args, DGLRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
}
template<typename R, typename ...Args>
template <typename R, typename... Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
return detail::typed_packed_call_dispatcher<R>
::run(packed_, std::forward<Args>(args)...);
return detail::typed_packed_call_dispatcher<R>::run(
packed_, std::forward<Args>(args)...);
}
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
template <typename T, typename TSrc, bool is_ext>
struct DGLValueCast {
static T Apply(const TSrc* self) {
return self->template AsObjectRef<T>();
}
static T Apply(const TSrc* self) { return self->template AsObjectRef<T>(); }
};
template<typename T, typename TSrc>
template <typename T, typename TSrc>
struct DGLValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
static T Apply(const TSrc* self) { return self->template AsExtension<T>(); }
};
} // namespace detail
template<typename T, typename>
template <typename T, typename>
inline DGLArgValue::operator T() const {
return detail::
DGLValueCast<T, DGLArgValue, extension_class_info<T>::code != 0>
::Apply(this);
return detail::DGLValueCast<
T, DGLArgValue, extension_class_info<T>::code != 0>::Apply(this);
}
template<typename T, typename>
template <typename T, typename>
inline DGLRetValue::operator T() const {
return detail::
DGLValueCast<T, DGLRetValue, extension_class_info<T>::code != 0>
::Apply(this);
return detail::DGLValueCast<
T, DGLRetValue, extension_class_info<T>::code != 0>::Apply(this);
}
template<typename T, typename>
template <typename T, typename>
inline void DGLArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
"Need to have extesion code");
static_assert(
extension_class_info<T>::code != 0, "Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling
template<typename T>
template <typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
static void destroy(void* handle) { delete static_cast<T*>(handle); }
static void* clone(void* handle) { return new T(*static_cast<T*>(handle)); }
};
template<typename T>
template <typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
static_assert(
code != 0,
"require extension_class_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
......@@ -1109,7 +1071,8 @@ inline ExtTypeVTable* ExtTypeVTable::Register_() {
// Implement Module::GetFunction
// Put implementation in this file so we have seen the PackedFunc
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
inline PackedFunc Module::GetFunction(
const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
......
......@@ -27,6 +27,7 @@
#include <string>
#include <vector>
#include "packed_func.h"
namespace dgl {
......@@ -61,7 +62,7 @@ class Registry {
* \tparam FType the signature of the function.
* \tparam FLambda The type of f.
*/
template<typename FType, typename FLambda>
template <typename FType, typename FLambda>
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
......@@ -71,7 +72,8 @@ class Registry {
* \param override Whether allow oveeride existing function.
* \return Reference to theregistry.
*/
DGL_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
DGL_DLL static Registry& Register(
const std::string& name, bool override = false); // NOLINT(*)
/*!
* \brief Erase global function from registry, if exist.
* \param name The name of the function.
......@@ -112,11 +114,11 @@ class Registry {
#define DGL_STR_CONCAT_(__x, __y) __x##__y
#define DGL_STR_CONCAT(__x, __y) DGL_STR_CONCAT_(__x, __y)
#define DGL_FUNC_REG_VAR_DEF \
static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::Registry& __mk_ ## DGL
#define DGL_FUNC_REG_VAR_DEF \
static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::Registry& __mk_##DGL
#define DGL_TYPE_REG_VAR_DEF \
static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::ExtTypeVTable* __mk_ ## DGLT
#define DGL_TYPE_REG_VAR_DEF \
static DGL_ATTRIBUTE_UNUSED ::dgl::runtime::ExtTypeVTable* __mk_##DGLT
/*!
* \brief Register a function globally.
......@@ -126,8 +128,8 @@ class Registry {
* });
* \endcode
*/
#define DGL_REGISTER_GLOBAL(OpName) \
DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \
#define DGL_REGISTER_GLOBAL(OpName) \
DGL_STR_CONCAT(DGL_FUNC_REG_VAR_DEF, __COUNTER__) = \
::dgl::runtime::Registry::Register(OpName)
/*!
......@@ -135,8 +137,8 @@ class Registry {
* This must be registered in a cc file
* after the trait extension_class_info is defined.
*/
#define DGL_REGISTER_EXT_TYPE(T) \
DGL_STR_CONCAT(DGL_TYPE_REG_VAR_DEF, __COUNTER__) = \
#define DGL_REGISTER_EXT_TYPE(T) \
DGL_STR_CONCAT(DGL_TYPE_REG_VAR_DEF, __COUNTER__) = \
::dgl::runtime::ExtTypeVTable::Register_<T>()
} // namespace runtime
......
......@@ -9,6 +9,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "smart_ptr_serializer.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment