"examples/krr_classification_ex.cpp" did not exist on "49532acf87315aa8514617fb9a40f2ecbd87c9d4"
Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
......@@ -32,8 +32,7 @@ DGL_DLL int DGLObjectFree(ObjectHandle handle);
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
int* out_index);
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key, int* out_index);
/*!
* \brief Get runtime type index of the object.
......@@ -41,8 +40,7 @@ DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
* \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
int* out_index);
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index);
/*!
* \brief get attributes given key
......@@ -54,11 +52,9 @@ DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
const char* key,
DGLValue* out_value,
int* out_type_code,
int* out_success);
DGL_DLL int DGLObjectGetAttr(
ObjectHandle handle, const char* key, DGLValue* out_value,
int* out_type_code, int* out_success);
/*!
* \brief get attributes names in the object.
......@@ -67,9 +63,8 @@ DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLObjectListAttrNames(ObjectHandle handle,
int *out_size,
const char*** out_array);
DGL_DLL int DGLObjectListAttrNames(
ObjectHandle handle, int* out_size, const char*** out_array);
#ifdef __cplusplus
} // DGL_EXTERN_C
#endif
......
......@@ -38,8 +38,8 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stddef.h>
#include <stdint.h>
/*! \brief type of array index. */
typedef int64_t dgl_index_t;
......@@ -60,7 +60,8 @@ typedef enum {
} DGLDeviceType;
/*!
* \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python.
* \brief The object type code is used in DGL FFI to indicate the types of
* objects passed between C and Python.
*/
typedef enum {
kInt = 0U,
......@@ -105,9 +106,9 @@ typedef enum {
} DGLDataTypeCode;
/*!
* \brief The data type the tensor can hold. The data type is assumed to follow the
* native endian-ness. An explicit error message should be raised when attempting to
* export an array with non-native endianness
* \brief The data type the tensor can hold. The data type is assumed to follow
* the native endian-ness. An explicit error message should be raised when
* attempting to export an array with non-native endianness
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
......@@ -149,12 +150,12 @@ typedef struct {
typedef struct {
/*!
* \brief The data pointer points to the allocated data.
*
*
* Depending on the device context, it can be a CPU pointer, or a CUDA
* device pointer or acl_mem handle in OpenCL.
* This pointer is always aligned to 256 bytes as in CUDA. Use the
* `byte_offset` field to mark the beginning of the actual data (if the address
* is not 256 byte aligned).
* `byte_offset` field to mark the beginning of the actual data (if the
* address is not 256 byte aligned).
*
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement
......@@ -247,7 +248,7 @@ DGL_DLL void DGLAPISetLastError(const char* msg);
* this function is threadsafe and can be called by different thread
* \return error info
*/
DGL_DLL const char *DGLGetLastError(void);
DGL_DLL const char* DGLGetLastError(void);
/*!
* \brief Load module from file.
* \param file_name The file name to load the module from.
......@@ -258,9 +259,8 @@ DGL_DLL const char *DGLGetLastError(void);
* \note The resulting module do not contain import relation.
* It can be reconstructed by DGLModImport.
*/
DGL_DLL int DGLModLoadFromFile(const char* file_name,
const char* format,
DGLModuleHandle* out);
DGL_DLL int DGLModLoadFromFile(
const char* file_name, const char* format, DGLModuleHandle* out);
/*!
* \brief Add dep to mod's dependency.
......@@ -270,8 +270,7 @@ DGL_DLL int DGLModLoadFromFile(const char* file_name,
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLModImport(DGLModuleHandle mod,
DGLModuleHandle dep);
DGL_DLL int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep);
/*!
* \brief Get function from the module.
......@@ -281,10 +280,9 @@ DGL_DLL int DGLModImport(DGLModuleHandle mod,
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLModGetFunction(DGLModuleHandle mod,
const char* func_name,
int query_imports,
DGLFunctionHandle *out);
DGL_DLL int DGLModGetFunction(
DGLModuleHandle mod, const char* func_name, int query_imports,
DGLFunctionHandle* out);
/*!
* \brief Free front-end extension type resource.
......@@ -334,12 +332,9 @@ DGL_DLL int DGLFuncFree(DGLFunctionHandle func);
* The front-end need to call free function (e.g. DGLFuncFree)
* to free these handles.
*/
DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
DGLValue* arg_values,
int* type_codes,
int num_args,
DGLValue* ret_val,
int* ret_type_code);
DGL_DLL int DGLFuncCall(
DGLFunctionHandle func, DGLValue* arg_values, int* type_codes, int num_args,
DGLValue* ret_val, int* ret_type_code);
/*!
* \brief Set the return value of DGLPackedCFunc.
......@@ -352,10 +347,8 @@ DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
* \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported.
*/
DGL_DLL int DGLCFuncSetReturn(DGLRetValueHandle ret,
DGLValue* value,
int* type_code,
int num_ret);
DGL_DLL int DGLCFuncSetReturn(
DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret);
/*!
* \brief Inplace translate callback argument value to return value.
......@@ -377,14 +370,12 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);
* \param num_args Number of arguments.
* \param ret The return value handle.
* \param resource_handle The handle additional resouce handle from fron-end.
* \return 0 if success, -1 if failure happens, set error via DGLAPISetLastError.
* \return 0 if success, -1 if failure happens, set error via
* DGLAPISetLastError.
* \sa DGLCFuncSetReturn
*/
typedef int (*DGLPackedCFunc)(
DGLValue* args,
int* type_codes,
int num_args,
DGLRetValueHandle ret,
DGLValue* args, int* type_codes, int num_args, DGLRetValueHandle ret,
void* resource_handle);
/*!
......@@ -407,18 +398,19 @@ typedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle);
/*!
* \brief Wrap a DGLPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by DGL API, until the function is no longer used.
* The resource_handle will be managed by DGL API, until the function is no
* longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param fin The finalizer on resource handle when the FunctionHandle get
* freed, can be NULL.
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens.
*/
DGL_DLL int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
void* resource_handle,
DGLPackedCFuncFinalizer fin,
DGLFunctionHandle *out);
DGL_DLL int DGLFuncCreateFromCFunc(
DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLFunctionHandle* out);
/*!
* \brief Register the function to runtime's global table.
......@@ -449,8 +441,7 @@ DGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out);
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLFuncListGlobalNames(int* out_size,
const char*** out_array);
DGL_DLL int DGLFuncListGlobalNames(int* out_size, const char*** out_array);
// Array related apis for quick proptyping
/*!
......@@ -467,14 +458,9 @@ DGL_DLL int DGLFuncListGlobalNames(int* out_size,
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out);
DGL_DLL int DGLArrayAlloc(
const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out);
/*!
* \brief Allocate a nd-array's with shared memory,
......@@ -490,14 +476,9 @@ DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
bool is_create,
DGLArrayHandle* out);
int DGLArrayAllocSharedMem(
const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out);
/*!
* \brief Free the DGL Array.
......@@ -513,9 +494,8 @@ DGL_DLL int DGLArrayFree(DGLArrayHandle handle);
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
void* data,
size_t nbytes);
DGL_DLL int DGLArrayCopyFromBytes(
DGLArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
......@@ -524,9 +504,8 @@ DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
void* data,
size_t nbytes);
DGL_DLL int DGLArrayCopyToBytes(
DGLArrayHandle handle, void* data, size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
......@@ -534,8 +513,7 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
* \param to The target space.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
DGLArrayHandle to);
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to);
/*!
* \brief Create a new runtime stream.
......@@ -545,7 +523,8 @@ DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
* \param out The new stream handle
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out);
DGL_DLL int DGLStreamCreate(
int device_type, int device_id, DGLStreamHandle* out);
/*!
* \brief Free a created stream handle.
......@@ -555,7 +534,8 @@ DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream);
DGL_DLL int DGLStreamFree(
int device_type, int device_id, DGLStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
......@@ -568,7 +548,8 @@ DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle);
DGL_DLL int DGLSetStream(
int device_type, int device_id, DGLStreamHandle handle);
/*!
* \brief Get the runtime stream of current thread.
......@@ -578,7 +559,8 @@ DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle)
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle);
DGL_DLL int DGLGetStream(
int device_type, int device_id, DGLStreamHandle* handle);
/*!
* \brief Wait until all computations on stream completes.
......@@ -588,7 +570,8 @@ DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream);
DGL_DLL int DGLSynchronize(
int device_type, int device_id, DGLStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
......@@ -599,16 +582,14 @@ DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle strea
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLStreamStreamSynchronize(int device_type,
int device_id,
DGLStreamHandle src,
DGLStreamHandle dst);
DGL_DLL int DGLStreamStreamSynchronize(
int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst);
/*!
* \brief Load tensor adapter.
* \return 0 when success, -1 when failure happens.
*/
DGL_DLL int DGLLoadTensorAdapter(const char *path);
DGL_DLL int DGLLoadTensorAdapter(const char* path);
/*!
* \brief Pin host memory.
......@@ -628,17 +609,18 @@ int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);
/*!
* \brief Bug report macro.
*
* This serves as a sanity check on system side to make sure the code is correct by
* checking whether a condition always holds for complex reasons. Failing the
* condition signifies a system bug instead of users giving invalid inputs or using
* the functionality incorrectly.
* This serves as a sanity check on system side to make sure the code is correct
* by checking whether a condition always holds for complex reasons. Failing
* the condition signifies a system bug instead of users giving invalid inputs
* or using the functionality incorrectly.
*
* Hints the user to file a bug report if the condition fails.
*/
#define BUG_IF_FAIL(cond) \
CHECK(cond) << "A bug has been occurred. " \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
#define BUG_IF_FAIL(cond) \
CHECK(cond) \
<< "A bug has been occurred. " \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
#ifdef __cplusplus
} // DGL_EXTERN_C
......
......@@ -4,7 +4,6 @@
* \brief DGL runtime config
*/
#ifndef DGL_RUNTIME_CONFIG_H_
#define DGL_RUNTIME_CONFIG_H_
......@@ -23,7 +22,7 @@ class Config {
bool IsLibxsmmAvailable() const;
private:
Config() = default;
Config() = default;
bool libxsmm_ = true;
};
......
......@@ -6,11 +6,12 @@
#ifndef DGL_RUNTIME_CONTAINER_H_
#define DGL_RUNTIME_CONTAINER_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <utility>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "object.h"
#include "packed_func.h"
......@@ -44,7 +45,7 @@ inline std::shared_ptr<ValueObject> MakeValue(T&& val) {
class Value : public ObjectRef {
public:
Value() {}
explicit Value(std::shared_ptr<Object> o): ObjectRef(o) {}
explicit Value(std::shared_ptr<Object> o) : ObjectRef(o) {}
const ValueObject* operator->() const {
return static_cast<const ValueObject*>(obj_.get());
......@@ -60,7 +61,7 @@ class ListObject : public Object {
std::vector<std::shared_ptr<Object> > data;
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to list have no effect.
// Visitor to list have no effect.
}
static constexpr const char* _type_key = "List";
......@@ -71,7 +72,7 @@ class ListObject : public Object {
class MapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
// Visitor to map have no effect.
}
// hash function
struct Hash {
......@@ -90,9 +91,7 @@ class MapObject : public Object {
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::shared_ptr<Object>,
std::shared_ptr<Object>,
Hash, Equal>;
std::shared_ptr<Object>, std::shared_ptr<Object>, Hash, Equal>;
/*! \brief the data content */
ContainerType data;
......@@ -101,17 +100,15 @@ class MapObject : public Object {
DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object);
};
/*! \brief specialized map obj with string as key */
class StrMapObject : public Object {
public:
void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect.
// Visitor to map have no effect.
}
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
std::string,
std::shared_ptr<Object> >;
using ContainerType =
std::unordered_map<std::string, std::shared_ptr<Object> >;
/*! \brief the data content */
ContainerType data;
......@@ -125,8 +122,7 @@ class StrMapObject : public Object {
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template<typename Converter,
typename TIter>
template <typename Converter, typename TIter>
class IterAdapter {
public:
explicit IterAdapter(TIter iter) : iter_(iter) {}
......@@ -144,9 +140,7 @@ class IterAdapter {
inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_;
}
inline bool operator!=(IterAdapter other) const {
return !(*this == other);
}
inline bool operator!=(IterAdapter other) const { return !(*this == other); }
inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_);
}
......@@ -175,7 +169,7 @@ class IterAdapter {
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
*
* Example:
*
* <code>
......@@ -183,31 +177,32 @@ class IterAdapter {
* // List<NDArray> list2; // fails
* List<Value> list; // works
* list.push_back(Value(MakeValue(1))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code>
*/
template<typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type >
template <
typename T,
typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class List : public ObjectRef {
public:
/*!
* \brief default constructor
*/
List() {
obj_ = std::make_shared<ListObject>();
}
List() { obj_ = std::make_shared<ListObject>(); }
/*!
* \brief move constructor
* \param other source
*/
List(List<T> && other) { // NOLINT(*)
List(List<T>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
List(const List<T> &other) : ObjectRef(other.obj_) { // NOLINT(*)
List(const List<T>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
......@@ -220,7 +215,7 @@ class List : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
List(IterType begin, IterType end) {
assign(begin, end);
}
......@@ -228,20 +223,19 @@ class List : public ObjectRef {
* \brief constructor from initializer list
* \param init The initalizer list
*/
List(std::initializer_list<T> init) { // NOLINT(*)
List(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
List(const std::vector<T>& init) { // NOLINT(*)
List(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
* \brief Constructs a container with n elements. Each element is a copy of
* val \param n The size of the container \param val The init value
*/
explicit List(size_t n, const T& val) {
auto tmp_obj = std::make_shared<ListObject>();
......@@ -255,7 +249,7 @@ class List : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(List<T> && other) {
List<T>& operator=(List<T>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
......@@ -264,7 +258,7 @@ class List : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
List<T>& operator=(const List<T> & other) {
List<T>& operator=(const List<T>& other) {
obj_ = other.obj_;
return *this;
}
......@@ -274,7 +268,7 @@ class List : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<ListObject>();
for (IterType it = begin; it != end; ++it) {
......@@ -304,7 +298,7 @@ class List : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline ListObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<ListObject>(
*static_cast<const ListObject*>(obj_.get()));
}
......@@ -328,9 +322,7 @@ class List : public ObjectRef {
n->data[i] = value.obj_;
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
/*! \brief Copy the content to a vector */
inline std::vector<T> ToVector() const {
return std::vector<T>(begin(), end());
......@@ -340,16 +332,14 @@ class List : public ObjectRef {
struct Ptr2ObjectRef {
using ResultType = T;
static inline T convert(const std::shared_ptr<Object>& n) {
return T(n);
}
static inline T convert(const std::shared_ptr<Object>& n) { return T(n); }
};
using iterator = IterAdapter<Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_iterator>;
using iterator = IterAdapter<
Ptr2ObjectRef, std::vector<std::shared_ptr<Object> >::const_iterator>;
using reverse_iterator = IterAdapter<
Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -361,11 +351,13 @@ class List : public ObjectRef {
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rbegin());
return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rend());
return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rend());
}
};
......@@ -390,7 +382,7 @@ class List : public ObjectRef {
* <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code>
*
*
* Example:
*
* <code>
......@@ -398,35 +390,35 @@ class List : public ObjectRef {
* // Map<std::string, NDArray> map2; // fails
* Map<std::string, Value> map; // works
* map.Set("key1", Value(MakeValue(1))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code>
*/
template<typename K,
typename V,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
template <
typename K, typename V,
typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value>::type,
typename =
typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Map() {
obj_ = std::make_shared<MapObject>();
}
Map() { obj_ = std::make_shared<MapObject>(); }
/*!
* \brief move constructor
* \param other source
*/
Map(Map<K, V> && other) { // NOLINT(*)
Map(Map<K, V>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
/*!
* \brief copy constructor
* \param other source
*/
Map(const Map<K, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
Map(const Map<K, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
/*!
* \brief constructor from pointer
......@@ -439,7 +431,7 @@ class Map : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
......@@ -447,15 +439,15 @@ class Map : public ObjectRef {
* \brief constructor from initializer list
* \param init The initalizer list
*/
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
template<typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
......@@ -463,7 +455,7 @@ class Map : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(Map<K, V> && other) {
Map<K, V>& operator=(Map<K, V>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
......@@ -472,7 +464,7 @@ class Map : public ObjectRef {
* \param other The source of assignment
* \return reference to self.
*/
Map<K, V>& operator=(const Map<K, V> & other) {
Map<K, V>& operator=(const Map<K, V>& other) {
obj_ = other.obj_;
return *this;
}
......@@ -482,12 +474,11 @@ class Map : public ObjectRef {
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::shared_ptr<MapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.obj_,
i->second.obj_));
n->data.emplace(std::make_pair(i->first.obj_, i->second.obj_));
}
obj_ = std::move(n);
}
......@@ -526,7 +517,7 @@ class Map : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique)
*/
inline MapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
......@@ -543,23 +534,20 @@ class Map : public ObjectRef {
}
/*! \return whether list is empty */
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
/*! \brief specify container obj */
using ContainerType = MapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair<
std::shared_ptr<Object>,
std::shared_ptr<Object> >& n) {
static inline ResultType convert(
const std::pair<std::shared_ptr<Object>, std::shared_ptr<Object> >& n) {
return std::make_pair(K(n.first), V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
using iterator =
IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -571,50 +559,49 @@ class Map : public ObjectRef {
}
/*! \return begin iterator */
inline iterator find(const K& key) const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
return iterator(
static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
}
};
// specialize of string map
template<typename V, typename T1, typename T2>
template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() {
obj_ = std::make_shared<StrMapObject>();
}
Map(Map<std::string, V> && other) { // NOLINT(*)
Map() { obj_ = std::make_shared<StrMapObject>(); }
Map(Map<std::string, V>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_);
}
Map(const Map<std::string, V> &other) : ObjectRef(other.obj_) { // NOLINT(*)
Map(const Map<std::string, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
}
explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}
template<typename IterType>
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}
template<typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
template <typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>&
init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V> && other) {
Map<std::string, V>& operator=(Map<std::string, V>&& other) {
obj_ = std::move(other.obj_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V> & other) {
Map<std::string, V>& operator=(const Map<std::string, V>& other) {
obj_ = other.obj_;
return *this;
}
template<typename IterType>
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = std::make_shared<StrMapObject>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first,
i->second.obj_));
n->data.emplace(std::make_pair(i->first, i->second.obj_));
}
obj_ = std::move(n);
}
......@@ -633,7 +620,7 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
return static_cast<const StrMapObject*>(obj_.get())->data.count(key);
}
inline StrMapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) {
if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get()));
}
......@@ -643,22 +630,19 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
StrMapObject* n = this->CopyOnWrite();
n->data[key] = value.obj_;
}
inline bool empty() const {
return size() == 0;
}
inline bool empty() const { return size() == 0; }
using ContainerType = StrMapObject;
struct Ptr2ObjectRef {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<
std::string,
std::shared_ptr<Object> >& n) {
static inline ResultType convert(
const std::pair<std::string, std::shared_ptr<Object> >& n) {
return std::make_pair(n.first, V(n.second));
}
};
using iterator = IterAdapter<
Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
using iterator =
IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
/*! \return begin iterator */
inline iterator begin() const {
......@@ -670,7 +654,8 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.find(key));
return iterator(
static_cast<const StrMapObject*>(obj_.get())->data.find(key));
}
};
......
......@@ -7,8 +7,9 @@
#define DGL_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
#include "packed_func.h"
namespace dgl {
namespace runtime {
......@@ -30,7 +31,8 @@ enum DeviceAttrKind : int {
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
/*! \brief Number of bytes each allocation must align to in temporary allocation
*/
constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
......@@ -47,9 +49,7 @@ class DeviceAPI {
/*!
* \brief Check whether the device is available.
*/
virtual bool IsAvailable() {
return true;
}
virtual bool IsAvailable() { return true; }
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
......@@ -62,7 +62,8 @@ class DeviceAPI {
* \param rv The return value.
* \sa DeviceAttrKind
*/
virtual void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
virtual void GetAttr(
DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
......@@ -72,10 +73,9 @@ class DeviceAPI {
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer.
*/
virtual void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
DGLDataType type_hint) = 0;
virtual void* AllocDataSpace(
DGLContext ctx, size_t nbytes, size_t alignment,
DGLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
......@@ -94,15 +94,11 @@ class DeviceAPI {
* \param type_hint The type of elements, only neded by certain backends.
* can be useful for cross device endian converison.
*/
virtual void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) = 0;
/*!
virtual void CopyDataFromTo(
const void* from, size_t from_offset, void* to, size_t to_offset,
size_t num_bytes, DGLContext ctx_from, DGLContext ctx_to,
DGLDataType type_hint) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
......@@ -145,31 +141,28 @@ class DeviceAPI {
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst);
DGL_DLL virtual void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst);
/*!
* \brief Pin host memory using cudaHostRegister().
*
* \param ptr The host memory pointer to be pinned.
* \param nbytes The size to be pinned.
*/
* \param nbytes The size to be pinned.
*/
DGL_DLL virtual void PinData(void* ptr, size_t nbytes);
/*!
* \brief Unpin host memory using cudaHostUnregister().
*
* \param ptr The host memory pointer to be unpinned.
*/
* \param ptr The host memory pointer to be unpinned.
*/
DGL_DLL virtual void UnpinData(void* ptr);
/*!
* \brief Check whether the memory is in pinned memory.
*/
DGL_DLL virtual bool IsPinned(const void* ptr) {
return false;
}
DGL_DLL virtual bool IsPinned(const void* ptr) { return false; }
/*!
* \brief Allocate temporal workspace for backend execution.
......@@ -180,16 +173,16 @@ class DeviceAPI {
* - Only a few allocation will happen, and space will be released after use.
* - The release order is usually in reverse order of allocate (stack style).
* - Repeative pattern of same allocations over different runs.
* - Workspace should not overlap between different threads(i.e. be threadlocal)
* - Workspace should not overlap between different threads(i.e. be
* threadlocal)
*
* \param ctx The context of allocation.
* \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
size_t nbytes,
DGLDataType type_hint = {});
DGL_DLL virtual void* AllocWorkspace(
DGLContext ctx, size_t nbytes, DGLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
......@@ -206,14 +199,14 @@ class DeviceAPI {
*/
DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);
/*!
* \brief Get device API based on context.
* \param dev_type The device type
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false);
DGL_DLL static DeviceAPI* Get(
DGLDeviceType dev_type, bool allow_missing = false);
};
/*! \brief The device type bigger than this is RPC device */
......
......@@ -31,10 +31,10 @@ struct DLPackConvert {
/*!
* \brief Deleter for NDArray converted from DLPack.
*
* This is used from data which is passed from external DLPack(DLManagedTensor)
* that are not allocated inside of DGL.
* This enables us to create NDArray from memory allocated by other
* frameworks that are DLPack compatible
* This is used from data which is passed from external
* DLPack(DLManagedTensor) that are not allocated inside of DGL. This enables
* us to create NDArray from memory allocated by other frameworks that are
* DLPack compatible
*/
static void DLPackDeleter(NDArray::Container* ptr);
......@@ -43,7 +43,7 @@ struct DLPackConvert {
* \param from The DGL NDArray.
* \return A DLPack tensor.
*/
static DLManagedTensor* ToDLPack(const NDArray &from);
static DLManagedTensor* ToDLPack(const NDArray& from);
};
} // namespace runtime
......@@ -66,8 +66,7 @@ DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
......@@ -76,8 +75,8 @@ DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
DGL_DLL int DGLArrayToDLPack(
DGLArrayHandle from, DLManagedTensor** out, int alignment = 0);
#ifdef __cplusplus
} // DGL_EXTERN_C
......
......@@ -9,10 +9,12 @@
#define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include <vector>
#include "c_runtime_api.h"
namespace dgl {
......@@ -29,8 +31,7 @@ class Module {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
explicit Module(std::shared_ptr<ModuleNode> n) : node_(n) {}
/*!
* \brief Get packed function from current module by name.
*
......@@ -40,7 +41,8 @@ class Module {
* This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
inline PackedFunc GetFunction(
const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
......@@ -61,8 +63,8 @@ class Module {
* \note This function won't load the import relationship.
* Re-create import relationship by calling Import.
*/
DGL_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");
DGL_DLL static Module LoadFromFile(
const std::string& file_name, const std::string& format = "");
private:
std::shared_ptr<ModuleNode> node_;
......@@ -103,8 +105,8 @@ class ModuleNode {
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format);
virtual void SaveToFile(
const std::string& file_name, const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
......@@ -128,9 +130,7 @@ class ModuleNode {
*/
DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */
const std::vector<Module>& imports() const {
return imports_;
}
const std::vector<Module>& imports() const { return imports_; }
protected:
friend class Module;
......@@ -139,8 +139,7 @@ class ModuleNode {
private:
/*! \brief Cache used by GetImport */
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
std::unordered_map<std::string, std::unique_ptr<PackedFunc> > import_cache_;
};
/*! \brief namespace for constant symbols */
......@@ -155,20 +154,19 @@ constexpr const char* dgl_dev_mblob_nbytes = "__dgl_dev_mblob_nbytes";
constexpr const char* dgl_set_device = "__dgl_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* dgl_global_barrier_state = "__dgl_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */
constexpr const char* dgl_prepare_global_barrier = "__dgl_prepare_global_barrier";
/*!
* \brief Prepare the global barrier before kernels that uses global barrier.
*/
constexpr const char* dgl_prepare_global_barrier =
"__dgl_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* dgl_module_main = "__dgl_main__";
} // namespace symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
return node_.get();
}
inline ModuleNode* Module::operator->() { return node_.get(); }
inline const ModuleNode* Module::operator->() const {
return node_.get();
}
inline const ModuleNode* Module::operator->() const { return node_.get(); }
} // namespace runtime
} // namespace dgl
......
......@@ -7,10 +7,11 @@
#define DGL_RUNTIME_OBJECT_H_
#include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
namespace dgl {
namespace runtime {
......@@ -26,7 +27,7 @@ class NDArray;
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
......@@ -35,14 +36,16 @@ class AttrVisitor {
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
template <
typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
static_assert(
std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr));
}
//! \endcond
//! \endcond
};
/*!
......@@ -87,20 +90,19 @@ class Object {
/*!
* \return whether the type is derived from
*/
template<typename T>
template <typename T>
inline bool derived_from() const;
/*!
* \return whether the object is of type T
* \tparam The type to be checked.
*/
template<typename T>
template <typename T>
inline bool is_type() const;
// object ref can see this
friend class ObjectRef;
static constexpr const char* _type_key = "Object";
};
/*! \brief base class of all reference object */
class ObjectRef {
public:
......@@ -109,7 +111,8 @@ class ObjectRef {
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
* Compare with the two are referencing to the same object (compare by
* address).
*
* \param other Another object ref.
* \return the compare result.
......@@ -119,7 +122,8 @@ class ObjectRef {
/*!
* \brief Comparator
*
* Compare with the two are referencing to the same object (compare by address).
* Compare with the two are referencing to the same object (compare by
* address).
*
* \param other Another object ref.
* \return the compare result.
......@@ -161,8 +165,8 @@ class ObjectRef {
* }
* \tparam T the target type, must be subtype of Object
*/
template<typename T>
inline const T *as() const;
template <typename T>
inline const T* as() const;
/*! \brief default constructor */
ObjectRef() = default;
......@@ -178,11 +182,11 @@ class ObjectRef {
* This is macro should be used in abstract base class definition
* because it does not define type_key and type_index.
*/
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*!
......@@ -206,69 +210,61 @@ class ObjectRef {
* DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass);
* };
*/
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \
return TypeName::_type_key; \
} \
uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { return TypeName::_type_key; } \
uint32_t type_index() const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
return tidx; \
} \
bool _DerivedFrom(uint32_t tid) const final { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
}
/*! \brief Macro to generate common object reference class method definition */
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj): BaseTypeName(obj) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(obj_.get()); \
} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(obj_.get()); \
} \
std::shared_ptr<ObjectName> sptr() const { \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
} \
operator bool() const { return this->defined(); } \
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj) \
: BaseTypeName(obj) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(obj_.get()); \
} \
ObjectName* operator->() { return static_cast<ObjectName*>(obj_.get()); } \
std::shared_ptr<ObjectName> sptr() const { \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = ObjectName
/*! \brief Macro to generate object reference class definition */
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \
public: \
DGL_DEFINE_OBJECT_REF_METHODS(TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \
public: \
DGL_DEFINE_OBJECT_REF_METHODS( \
TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
}
// implementations of inline functions after this
template<typename T>
template <typename T>
inline bool Object::is_type() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return type_id == this->type_index();
}
template<typename T>
template <typename T>
inline bool Object::derived_from() const {
// use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id);
}
inline const Object* ObjectRef::get() const {
return obj_.get();
}
inline const Object* ObjectRef::get() const { return obj_.get(); }
inline const Object* ObjectRef::operator->() const {
return obj_.get();
}
inline const Object* ObjectRef::operator->() const { return obj_.get(); }
inline bool ObjectRef::defined() const {
return obj_.get() != nullptr;
}
inline bool ObjectRef::defined() const { return obj_.get() != nullptr; }
inline bool ObjectRef::operator==(const ObjectRef& other) const {
return obj_.get() == other.obj_.get();
......@@ -295,7 +291,7 @@ inline uint32_t ObjectRef::type_index() const {
return get()->type_index();
}
template<typename T>
template <typename T>
inline const T* ObjectRef::as() const {
const Object* ptr = get();
if (ptr && ptr->is_type<T>()) {
......@@ -306,9 +302,7 @@ inline const T* ObjectRef::as() const {
/*! \brief The hash function for nodes */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
return a.hash();
}
size_t operator()(const ObjectRef& a) const { return a.hash(); }
};
/*! \brief The equal comparator for nodes */
......
This diff is collapsed.
This diff is collapsed.
......@@ -9,6 +9,7 @@
#include <dmlc/io.h>
#include <dmlc/serializer.h>
#include "c_runtime_api.h"
#include "smart_ptr_serializer.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment