Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
...@@ -32,8 +32,7 @@ DGL_DLL int DGLObjectFree(ObjectHandle handle); ...@@ -32,8 +32,7 @@ DGL_DLL int DGLObjectFree(ObjectHandle handle);
* \param out_index the corresponding type index. * \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLObjectTypeKey2Index(const char* type_key, DGL_DLL int DGLObjectTypeKey2Index(const char* type_key, int* out_index);
int* out_index);
/*! /*!
* \brief Get runtime type index of the object. * \brief Get runtime type index of the object.
...@@ -41,8 +40,7 @@ DGL_DLL int DGLObjectTypeKey2Index(const char* type_key, ...@@ -41,8 +40,7 @@ DGL_DLL int DGLObjectTypeKey2Index(const char* type_key,
* \param out_index the corresponding type index. * \param out_index the corresponding type index.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index);
int* out_index);
/*! /*!
* \brief get attributes given key * \brief get attributes given key
...@@ -54,11 +52,9 @@ DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle, ...@@ -54,11 +52,9 @@ DGL_DLL int DGLObjectGetTypeIndex(ObjectHandle handle,
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1 * \note API calls always exchanges with type bits=64, lanes=1
*/ */
DGL_DLL int DGLObjectGetAttr(ObjectHandle handle, DGL_DLL int DGLObjectGetAttr(
const char* key, ObjectHandle handle, const char* key, DGLValue* out_value,
DGLValue* out_value, int* out_type_code, int* out_success);
int* out_type_code,
int* out_success);
/*! /*!
* \brief get attributes names in the object. * \brief get attributes names in the object.
...@@ -67,9 +63,8 @@ DGL_DLL int DGLObjectGetAttr(ObjectHandle handle, ...@@ -67,9 +63,8 @@ DGL_DLL int DGLObjectGetAttr(ObjectHandle handle,
* \param out_array The array of function names. * \param out_array The array of function names.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLObjectListAttrNames(ObjectHandle handle, DGL_DLL int DGLObjectListAttrNames(
int *out_size, ObjectHandle handle, int* out_size, const char*** out_array);
const char*** out_array);
#ifdef __cplusplus #ifdef __cplusplus
} // DGL_EXTERN_C } // DGL_EXTERN_C
#endif #endif
......
...@@ -38,8 +38,8 @@ ...@@ -38,8 +38,8 @@
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
#include <stdint.h>
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
/*! \brief type of array index. */ /*! \brief type of array index. */
typedef int64_t dgl_index_t; typedef int64_t dgl_index_t;
...@@ -60,7 +60,8 @@ typedef enum { ...@@ -60,7 +60,8 @@ typedef enum {
} DGLDeviceType; } DGLDeviceType;
/*! /*!
* \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python. * \brief The object type code is used in DGL FFI to indicate the types of
* objects passed between C and Python.
*/ */
typedef enum { typedef enum {
kInt = 0U, kInt = 0U,
...@@ -105,9 +106,9 @@ typedef enum { ...@@ -105,9 +106,9 @@ typedef enum {
} DGLDataTypeCode; } DGLDataTypeCode;
/*! /*!
* \brief The data type the tensor can hold. The data type is assumed to follow the * \brief The data type the tensor can hold. The data type is assumed to follow
* native endian-ness. An explicit error message should be raised when attempting to * the native endian-ness. An explicit error message should be raised when
* export an array with non-native endianness * attempting to export an array with non-native endianness
* *
* Examples * Examples
* - float: type_code = 2, bits = 32, lanes=1 * - float: type_code = 2, bits = 32, lanes=1
...@@ -149,12 +150,12 @@ typedef struct { ...@@ -149,12 +150,12 @@ typedef struct {
typedef struct { typedef struct {
/*! /*!
* \brief The data pointer points to the allocated data. * \brief The data pointer points to the allocated data.
* *
* Depending on the device context, it can be a CPU pointer, or a CUDA * Depending on the device context, it can be a CPU pointer, or a CUDA
* device pointer or acl_mem handle in OpenCL. * device pointer or acl_mem handle in OpenCL.
* This pointer is always aligned to 256 bytes as in CUDA. Use the * This pointer is always aligned to 256 bytes as in CUDA. Use the
* `byte_offset` field to mark the beginning of the actual data (if the address * `byte_offset` field to mark the beginning of the actual data (if the
* is not 256 byte aligned). * address is not 256 byte aligned).
* *
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
...@@ -247,7 +248,7 @@ DGL_DLL void DGLAPISetLastError(const char* msg); ...@@ -247,7 +248,7 @@ DGL_DLL void DGLAPISetLastError(const char* msg);
* this function is threadsafe and can be called by different thread * this function is threadsafe and can be called by different thread
* \return error info * \return error info
*/ */
DGL_DLL const char *DGLGetLastError(void); DGL_DLL const char* DGLGetLastError(void);
/*! /*!
* \brief Load module from file. * \brief Load module from file.
* \param file_name The file name to load the module from. * \param file_name The file name to load the module from.
...@@ -258,9 +259,8 @@ DGL_DLL const char *DGLGetLastError(void); ...@@ -258,9 +259,8 @@ DGL_DLL const char *DGLGetLastError(void);
* \note The resulting module do not contain import relation. * \note The resulting module do not contain import relation.
* It can be reconstructed by DGLModImport. * It can be reconstructed by DGLModImport.
*/ */
DGL_DLL int DGLModLoadFromFile(const char* file_name, DGL_DLL int DGLModLoadFromFile(
const char* format, const char* file_name, const char* format, DGLModuleHandle* out);
DGLModuleHandle* out);
/*! /*!
* \brief Add dep to mod's dependency. * \brief Add dep to mod's dependency.
...@@ -270,8 +270,7 @@ DGL_DLL int DGLModLoadFromFile(const char* file_name, ...@@ -270,8 +270,7 @@ DGL_DLL int DGLModLoadFromFile(const char* file_name,
* \param dep The dependent module to be imported. * \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLModImport(DGLModuleHandle mod, DGL_DLL int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep);
DGLModuleHandle dep);
/*! /*!
* \brief Get function from the module. * \brief Get function from the module.
...@@ -281,10 +280,9 @@ DGL_DLL int DGLModImport(DGLModuleHandle mod, ...@@ -281,10 +280,9 @@ DGL_DLL int DGLModImport(DGLModuleHandle mod,
* \param out The result function, can be NULL if it is not available. * \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLModGetFunction(DGLModuleHandle mod, DGL_DLL int DGLModGetFunction(
const char* func_name, DGLModuleHandle mod, const char* func_name, int query_imports,
int query_imports, DGLFunctionHandle* out);
DGLFunctionHandle *out);
/*! /*!
* \brief Free front-end extension type resource. * \brief Free front-end extension type resource.
...@@ -334,12 +332,9 @@ DGL_DLL int DGLFuncFree(DGLFunctionHandle func); ...@@ -334,12 +332,9 @@ DGL_DLL int DGLFuncFree(DGLFunctionHandle func);
* The front-end need to call free function (e.g. DGLFuncFree) * The front-end need to call free function (e.g. DGLFuncFree)
* to free these handles. * to free these handles.
*/ */
DGL_DLL int DGLFuncCall(DGLFunctionHandle func, DGL_DLL int DGLFuncCall(
DGLValue* arg_values, DGLFunctionHandle func, DGLValue* arg_values, int* type_codes, int num_args,
int* type_codes, DGLValue* ret_val, int* ret_type_code);
int num_args,
DGLValue* ret_val,
int* ret_type_code);
/*! /*!
* \brief Set the return value of DGLPackedCFunc. * \brief Set the return value of DGLPackedCFunc.
...@@ -352,10 +347,8 @@ DGL_DLL int DGLFuncCall(DGLFunctionHandle func, ...@@ -352,10 +347,8 @@ DGL_DLL int DGLFuncCall(DGLFunctionHandle func,
* \param type_code The type of the value to be returned. * \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported. * \param num_ret Number of return values, for now only 1 is supported.
*/ */
DGL_DLL int DGLCFuncSetReturn(DGLRetValueHandle ret, DGL_DLL int DGLCFuncSetReturn(
DGLValue* value, DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret);
int* type_code,
int num_ret);
/*! /*!
* \brief Inplace translate callback argument value to return value. * \brief Inplace translate callback argument value to return value.
...@@ -377,14 +370,12 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code); ...@@ -377,14 +370,12 @@ DGL_DLL int DGLCbArgToReturn(DGLValue* value, int code);
* \param num_args Number of arguments. * \param num_args Number of arguments.
* \param ret The return value handle. * \param ret The return value handle.
* \param resource_handle The handle additional resouce handle from fron-end. * \param resource_handle The handle additional resouce handle from fron-end.
* \return 0 if success, -1 if failure happens, set error via DGLAPISetLastError. * \return 0 if success, -1 if failure happens, set error via
* DGLAPISetLastError.
* \sa DGLCFuncSetReturn * \sa DGLCFuncSetReturn
*/ */
typedef int (*DGLPackedCFunc)( typedef int (*DGLPackedCFunc)(
DGLValue* args, DGLValue* args, int* type_codes, int num_args, DGLRetValueHandle ret,
int* type_codes,
int num_args,
DGLRetValueHandle ret,
void* resource_handle); void* resource_handle);
/*! /*!
...@@ -407,18 +398,19 @@ typedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle); ...@@ -407,18 +398,19 @@ typedef int (*DGLExtensionFuncDeclarer)(DGLFunctionHandle register_func_handle);
/*! /*!
* \brief Wrap a DGLPackedCFunc to become a FunctionHandle. * \brief Wrap a DGLPackedCFunc to become a FunctionHandle.
* *
* The resource_handle will be managed by DGL API, until the function is no longer used. * The resource_handle will be managed by DGL API, until the function is no
* longer used.
* *
* \param func The packed C function. * \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL. * \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL * \param fin The finalizer on resource handle when the FunctionHandle get
* freed, can be NULL.
* \param out the result function handle. * \param out the result function handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens.
*/ */
DGL_DLL int DGLFuncCreateFromCFunc(DGLPackedCFunc func, DGL_DLL int DGLFuncCreateFromCFunc(
void* resource_handle, DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
DGLPackedCFuncFinalizer fin, DGLFunctionHandle* out);
DGLFunctionHandle *out);
/*! /*!
* \brief Register the function to runtime's global table. * \brief Register the function to runtime's global table.
...@@ -449,8 +441,7 @@ DGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out); ...@@ -449,8 +441,7 @@ DGL_DLL int DGLFuncGetGlobal(const char* name, DGLFunctionHandle* out);
* \param out_array The array of function names. * \param out_array The array of function names.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLFuncListGlobalNames(int* out_size, DGL_DLL int DGLFuncListGlobalNames(int* out_size, const char*** out_array);
const char*** out_array);
// Array related apis for quick proptyping // Array related apis for quick proptyping
/*! /*!
...@@ -467,14 +458,9 @@ DGL_DLL int DGLFuncListGlobalNames(int* out_size, ...@@ -467,14 +458,9 @@ DGL_DLL int DGLFuncListGlobalNames(int* out_size,
* \param out The output handle. * \param out The output handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape, DGL_DLL int DGLArrayAlloc(
int ndim, const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_code, int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out);
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out);
/*! /*!
* \brief Allocate a nd-array's with shared memory, * \brief Allocate a nd-array's with shared memory,
...@@ -490,14 +476,9 @@ DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape, ...@@ -490,14 +476,9 @@ DGL_DLL int DGLArrayAlloc(const dgl_index_t* shape,
* \param out The output handle. * \param out The output handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
int DGLArrayAllocSharedMem(const char *mem_name, int DGLArrayAllocSharedMem(
const dgl_index_t *shape, const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int ndim, int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out);
int dtype_code,
int dtype_bits,
int dtype_lanes,
bool is_create,
DGLArrayHandle* out);
/*! /*!
* \brief Free the DGL Array. * \brief Free the DGL Array.
...@@ -513,9 +494,8 @@ DGL_DLL int DGLArrayFree(DGLArrayHandle handle); ...@@ -513,9 +494,8 @@ DGL_DLL int DGLArrayFree(DGLArrayHandle handle);
* \param nbytes The number of bytes to copy. * \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle, DGL_DLL int DGLArrayCopyFromBytes(
void* data, DGLArrayHandle handle, void* data, size_t nbytes);
size_t nbytes);
/*! /*!
* \brief Copy array data to CPU byte array. * \brief Copy array data to CPU byte array.
...@@ -524,9 +504,8 @@ DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle, ...@@ -524,9 +504,8 @@ DGL_DLL int DGLArrayCopyFromBytes(DGLArrayHandle handle,
* \param nbytes The number of bytes to copy. * \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle, DGL_DLL int DGLArrayCopyToBytes(
void* data, DGLArrayHandle handle, void* data, size_t nbytes);
size_t nbytes);
/*! /*!
* \brief Copy the array, both from and to must be valid during the copy. * \brief Copy the array, both from and to must be valid during the copy.
...@@ -534,8 +513,7 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle, ...@@ -534,8 +513,7 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
* \param to The target space. * \param to The target space.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to);
DGLArrayHandle to);
/*! /*!
* \brief Create a new runtime stream. * \brief Create a new runtime stream.
...@@ -545,7 +523,8 @@ DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from, ...@@ -545,7 +523,8 @@ DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
* \param out The new stream handle * \param out The new stream handle
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out); DGL_DLL int DGLStreamCreate(
int device_type, int device_id, DGLStreamHandle* out);
/*! /*!
* \brief Free a created stream handle. * \brief Free a created stream handle.
...@@ -555,7 +534,8 @@ DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out ...@@ -555,7 +534,8 @@ DGL_DLL int DGLStreamCreate(int device_type, int device_id, DGLStreamHandle* out
* \param stream The stream to be freed * \param stream The stream to be freed
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream); DGL_DLL int DGLStreamFree(
int device_type, int device_id, DGLStreamHandle stream);
/*! /*!
* \brief Set the runtime stream of current thread to be stream. * \brief Set the runtime stream of current thread to be stream.
...@@ -568,7 +548,8 @@ DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream ...@@ -568,7 +548,8 @@ DGL_DLL int DGLStreamFree(int device_type, int device_id, DGLStreamHandle stream
* \param handle The stream handle. * \param handle The stream handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle); DGL_DLL int DGLSetStream(
int device_type, int device_id, DGLStreamHandle handle);
/*! /*!
* \brief Get the runtime stream of current thread. * \brief Get the runtime stream of current thread.
...@@ -578,7 +559,8 @@ DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle) ...@@ -578,7 +559,8 @@ DGL_DLL int DGLSetStream(int device_type, int device_id, DGLStreamHandle handle)
* \param handle The stream handle. * \param handle The stream handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle); DGL_DLL int DGLGetStream(
int device_type, int device_id, DGLStreamHandle* handle);
/*! /*!
* \brief Wait until all computations on stream completes. * \brief Wait until all computations on stream completes.
...@@ -588,7 +570,8 @@ DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle ...@@ -588,7 +570,8 @@ DGL_DLL int DGLGetStream(int device_type, int device_id, DGLStreamHandle* handle
* \param stream The stream to be synchronized. * \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream); DGL_DLL int DGLSynchronize(
int device_type, int device_id, DGLStreamHandle stream);
/*! /*!
* \brief Synchronize two streams of execution. * \brief Synchronize two streams of execution.
...@@ -599,16 +582,14 @@ DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle strea ...@@ -599,16 +582,14 @@ DGL_DLL int DGLSynchronize(int device_type, int device_id, DGLStreamHandle strea
* \param dst The destination stream to synchronize. * \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLStreamStreamSynchronize(int device_type, DGL_DLL int DGLStreamStreamSynchronize(
int device_id, int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst);
DGLStreamHandle src,
DGLStreamHandle dst);
/*! /*!
* \brief Load tensor adapter. * \brief Load tensor adapter.
* \return 0 when success, -1 when failure happens. * \return 0 when success, -1 when failure happens.
*/ */
DGL_DLL int DGLLoadTensorAdapter(const char *path); DGL_DLL int DGLLoadTensorAdapter(const char* path);
/*! /*!
* \brief Pin host memory. * \brief Pin host memory.
...@@ -628,17 +609,18 @@ int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream); ...@@ -628,17 +609,18 @@ int DGLArrayRecordStream(DGLArrayHandle handle, DGLStreamHandle stream);
/*! /*!
* \brief Bug report macro. * \brief Bug report macro.
* *
* This serves as a sanity check on system side to make sure the code is correct by * This serves as a sanity check on system side to make sure the code is correct
* checking whether a condition always holds for complex reasons. Failing the * by checking whether a condition always holds for complex reasons. Failing
* condition signifies a system bug instead of users giving invalid inputs or using * the condition signifies a system bug instead of users giving invalid inputs
* the functionality incorrectly. * or using the functionality incorrectly.
* *
* Hints the user to file a bug report if the condition fails. * Hints the user to file a bug report if the condition fails.
*/ */
#define BUG_IF_FAIL(cond) \ #define BUG_IF_FAIL(cond) \
CHECK(cond) << "A bug has been occurred. " \ CHECK(cond) \
"Please file a bug report at https://github.com/dmlc/dgl/issues. " \ << "A bug has been occurred. " \
"Message: " "Please file a bug report at https://github.com/dmlc/dgl/issues. " \
"Message: "
#ifdef __cplusplus #ifdef __cplusplus
} // DGL_EXTERN_C } // DGL_EXTERN_C
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
* \brief DGL runtime config * \brief DGL runtime config
*/ */
#ifndef DGL_RUNTIME_CONFIG_H_ #ifndef DGL_RUNTIME_CONFIG_H_
#define DGL_RUNTIME_CONFIG_H_ #define DGL_RUNTIME_CONFIG_H_
...@@ -23,7 +22,7 @@ class Config { ...@@ -23,7 +22,7 @@ class Config {
bool IsLibxsmmAvailable() const; bool IsLibxsmmAvailable() const;
private: private:
Config() = default; Config() = default;
bool libxsmm_ = true; bool libxsmm_ = true;
}; };
......
...@@ -6,11 +6,12 @@ ...@@ -6,11 +6,12 @@
#ifndef DGL_RUNTIME_CONTAINER_H_ #ifndef DGL_RUNTIME_CONTAINER_H_
#define DGL_RUNTIME_CONTAINER_H_ #define DGL_RUNTIME_CONTAINER_H_
#include <unordered_map>
#include <vector>
#include <memory> #include <memory>
#include <utility>
#include <string> #include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "object.h" #include "object.h"
#include "packed_func.h" #include "packed_func.h"
...@@ -44,7 +45,7 @@ inline std::shared_ptr<ValueObject> MakeValue(T&& val) { ...@@ -44,7 +45,7 @@ inline std::shared_ptr<ValueObject> MakeValue(T&& val) {
class Value : public ObjectRef { class Value : public ObjectRef {
public: public:
Value() {} Value() {}
explicit Value(std::shared_ptr<Object> o): ObjectRef(o) {} explicit Value(std::shared_ptr<Object> o) : ObjectRef(o) {}
const ValueObject* operator->() const { const ValueObject* operator->() const {
return static_cast<const ValueObject*>(obj_.get()); return static_cast<const ValueObject*>(obj_.get());
...@@ -60,7 +61,7 @@ class ListObject : public Object { ...@@ -60,7 +61,7 @@ class ListObject : public Object {
std::vector<std::shared_ptr<Object> > data; std::vector<std::shared_ptr<Object> > data;
void VisitAttrs(AttrVisitor* visitor) final { void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to list have no effect. // Visitor to list have no effect.
} }
static constexpr const char* _type_key = "List"; static constexpr const char* _type_key = "List";
...@@ -71,7 +72,7 @@ class ListObject : public Object { ...@@ -71,7 +72,7 @@ class ListObject : public Object {
class MapObject : public Object { class MapObject : public Object {
public: public:
void VisitAttrs(AttrVisitor* visitor) final { void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect. // Visitor to map have no effect.
} }
// hash function // hash function
struct Hash { struct Hash {
...@@ -90,9 +91,7 @@ class MapObject : public Object { ...@@ -90,9 +91,7 @@ class MapObject : public Object {
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map< using ContainerType = std::unordered_map<
std::shared_ptr<Object>, std::shared_ptr<Object>, std::shared_ptr<Object>, Hash, Equal>;
std::shared_ptr<Object>,
Hash, Equal>;
/*! \brief the data content */ /*! \brief the data content */
ContainerType data; ContainerType data;
...@@ -101,17 +100,15 @@ class MapObject : public Object { ...@@ -101,17 +100,15 @@ class MapObject : public Object {
DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object); DGL_DECLARE_OBJECT_TYPE_INFO(MapObject, Object);
}; };
/*! \brief specialized map obj with string as key */ /*! \brief specialized map obj with string as key */
class StrMapObject : public Object { class StrMapObject : public Object {
public: public:
void VisitAttrs(AttrVisitor* visitor) final { void VisitAttrs(AttrVisitor* visitor) final {
// Visitor to map have no effect. // Visitor to map have no effect.
} }
/*! \brief The corresponding conatiner type */ /*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map< using ContainerType =
std::string, std::unordered_map<std::string, std::shared_ptr<Object> >;
std::shared_ptr<Object> >;
/*! \brief the data content */ /*! \brief the data content */
ContainerType data; ContainerType data;
...@@ -125,8 +122,7 @@ class StrMapObject : public Object { ...@@ -125,8 +122,7 @@ class StrMapObject : public Object {
* \tparam Converter a struct that contains converting function * \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type. * \tparam TIter the content iterator type.
*/ */
template<typename Converter, template <typename Converter, typename TIter>
typename TIter>
class IterAdapter { class IterAdapter {
public: public:
explicit IterAdapter(TIter iter) : iter_(iter) {} explicit IterAdapter(TIter iter) : iter_(iter) {}
...@@ -144,9 +140,7 @@ class IterAdapter { ...@@ -144,9 +140,7 @@ class IterAdapter {
inline bool operator==(IterAdapter other) const { inline bool operator==(IterAdapter other) const {
return iter_ == other.iter_; return iter_ == other.iter_;
} }
inline bool operator!=(IterAdapter other) const { inline bool operator!=(IterAdapter other) const { return !(*this == other); }
return !(*this == other);
}
inline const typename Converter::ResultType operator*() const { inline const typename Converter::ResultType operator*() const {
return Converter::convert(*iter_); return Converter::convert(*iter_);
} }
...@@ -175,7 +169,7 @@ class IterAdapter { ...@@ -175,7 +169,7 @@ class IterAdapter {
* <code> * <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>' * error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code> * </code>
* *
* Example: * Example:
* *
* <code> * <code>
...@@ -183,31 +177,32 @@ class IterAdapter { ...@@ -183,31 +177,32 @@ class IterAdapter {
* // List<NDArray> list2; // fails * // List<NDArray> list2; // fails
* List<Value> list; // works * List<Value> list; // works
* list.push_back(Value(MakeValue(1))); // works * list.push_back(Value(MakeValue(1))); // works
* list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works * list.push_back(Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code> * </code>
*/ */
template<typename T, template <
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type > typename T,
typename =
typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class List : public ObjectRef { class List : public ObjectRef {
public: public:
/*! /*!
* \brief default constructor * \brief default constructor
*/ */
List() { List() { obj_ = std::make_shared<ListObject>(); }
obj_ = std::make_shared<ListObject>();
}
/*! /*!
* \brief move constructor * \brief move constructor
* \param other source * \param other source
*/ */
List(List<T> && other) { // NOLINT(*) List(List<T>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
} }
/*! /*!
* \brief copy constructor * \brief copy constructor
* \param other source * \param other source
*/ */
List(const List<T> &other) : ObjectRef(other.obj_) { // NOLINT(*) List(const List<T>& other) : ObjectRef(other.obj_) { // NOLINT(*)
} }
/*! /*!
* \brief constructor from pointer * \brief constructor from pointer
...@@ -220,7 +215,7 @@ class List : public ObjectRef { ...@@ -220,7 +215,7 @@ class List : public ObjectRef {
* \param end end of iterator * \param end end of iterator
* \tparam IterType The type of iterator * \tparam IterType The type of iterator
*/ */
template<typename IterType> template <typename IterType>
List(IterType begin, IterType end) { List(IterType begin, IterType end) {
assign(begin, end); assign(begin, end);
} }
...@@ -228,20 +223,19 @@ class List : public ObjectRef { ...@@ -228,20 +223,19 @@ class List : public ObjectRef {
* \brief constructor from initializer list * \brief constructor from initializer list
* \param init The initalizer list * \param init The initalizer list
*/ */
List(std::initializer_list<T> init) { // NOLINT(*) List(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
/*! /*!
* \brief constructor from vector * \brief constructor from vector
* \param init The vector * \param init The vector
*/ */
List(const std::vector<T>& init) { // NOLINT(*) List(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
/*! /*!
* \brief Constructs a container with n elements. Each element is a copy of val * \brief Constructs a container with n elements. Each element is a copy of
* \param n The size of the container * val \param n The size of the container \param val The init value
* \param val The init value
*/ */
explicit List(size_t n, const T& val) { explicit List(size_t n, const T& val) {
auto tmp_obj = std::make_shared<ListObject>(); auto tmp_obj = std::make_shared<ListObject>();
...@@ -255,7 +249,7 @@ class List : public ObjectRef { ...@@ -255,7 +249,7 @@ class List : public ObjectRef {
* \param other The source of assignment * \param other The source of assignment
* \return reference to self. * \return reference to self.
*/ */
List<T>& operator=(List<T> && other) { List<T>& operator=(List<T>&& other) {
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
return *this; return *this;
} }
...@@ -264,7 +258,7 @@ class List : public ObjectRef { ...@@ -264,7 +258,7 @@ class List : public ObjectRef {
* \param other The source of assignment * \param other The source of assignment
* \return reference to self. * \return reference to self.
*/ */
List<T>& operator=(const List<T> & other) { List<T>& operator=(const List<T>& other) {
obj_ = other.obj_; obj_ = other.obj_;
return *this; return *this;
} }
...@@ -274,7 +268,7 @@ class List : public ObjectRef { ...@@ -274,7 +268,7 @@ class List : public ObjectRef {
* \param end end of iterator * \param end end of iterator
* \tparam IterType The type of iterator * \tparam IterType The type of iterator
*/ */
template<typename IterType> template <typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
auto n = std::make_shared<ListObject>(); auto n = std::make_shared<ListObject>();
for (IterType it = begin; it != end; ++it) { for (IterType it = begin; it != end; ++it) {
...@@ -304,7 +298,7 @@ class List : public ObjectRef { ...@@ -304,7 +298,7 @@ class List : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique) * \return Handle to the internal obj container(which ganrantees to be unique)
*/ */
inline ListObject* CopyOnWrite() { inline ListObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) { if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<ListObject>( obj_ = std::make_shared<ListObject>(
*static_cast<const ListObject*>(obj_.get())); *static_cast<const ListObject*>(obj_.get()));
} }
...@@ -328,9 +322,7 @@ class List : public ObjectRef { ...@@ -328,9 +322,7 @@ class List : public ObjectRef {
n->data[i] = value.obj_; n->data[i] = value.obj_;
} }
/*! \return whether list is empty */ /*! \return whether list is empty */
inline bool empty() const { inline bool empty() const { return size() == 0; }
return size() == 0;
}
/*! \brief Copy the content to a vector */ /*! \brief Copy the content to a vector */
inline std::vector<T> ToVector() const { inline std::vector<T> ToVector() const {
return std::vector<T>(begin(), end()); return std::vector<T>(begin(), end());
...@@ -340,16 +332,14 @@ class List : public ObjectRef { ...@@ -340,16 +332,14 @@ class List : public ObjectRef {
struct Ptr2ObjectRef { struct Ptr2ObjectRef {
using ResultType = T; using ResultType = T;
static inline T convert(const std::shared_ptr<Object>& n) { static inline T convert(const std::shared_ptr<Object>& n) { return T(n); }
return T(n);
}
}; };
using iterator = IterAdapter<Ptr2ObjectRef, using iterator = IterAdapter<
std::vector<std::shared_ptr<Object> >::const_iterator>; Ptr2ObjectRef, std::vector<std::shared_ptr<Object> >::const_iterator>;
using reverse_iterator = IterAdapter< using reverse_iterator = IterAdapter<
Ptr2ObjectRef, Ptr2ObjectRef,
std::vector<std::shared_ptr<Object> >::const_reverse_iterator>; std::vector<std::shared_ptr<Object> >::const_reverse_iterator>;
/*! \return begin iterator */ /*! \return begin iterator */
inline iterator begin() const { inline iterator begin() const {
...@@ -361,11 +351,13 @@ class List : public ObjectRef { ...@@ -361,11 +351,13 @@ class List : public ObjectRef {
} }
/*! \return rbegin iterator */ /*! \return rbegin iterator */
inline reverse_iterator rbegin() const { inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rbegin()); return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rbegin());
} }
/*! \return rend iterator */ /*! \return rend iterator */
inline reverse_iterator rend() const { inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ListObject*>(obj_.get())->data.rend()); return reverse_iterator(
static_cast<const ListObject*>(obj_.get())->data.rend());
} }
}; };
...@@ -390,7 +382,7 @@ class List : public ObjectRef { ...@@ -390,7 +382,7 @@ class List : public ObjectRef {
* <code> * <code>
* error: no type named 'type' in 'struct std::enable_if<false, void>' * error: no type named 'type' in 'struct std::enable_if<false, void>'
* </code> * </code>
* *
* Example: * Example:
* *
* <code> * <code>
...@@ -398,35 +390,35 @@ class List : public ObjectRef { ...@@ -398,35 +390,35 @@ class List : public ObjectRef {
* // Map<std::string, NDArray> map2; // fails * // Map<std::string, NDArray> map2; // fails
* Map<std::string, Value> map; // works * Map<std::string, Value> map; // works
* map.Set("key1", Value(MakeValue(1))); // works * map.Set("key1", Value(MakeValue(1))); // works
* map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); // works * map.Set("key2", Value(MakeValue(NDArray::Empty(shape, dtype, ctx)))); //
* works
* </code> * </code>
*/ */
template<typename K, template <
typename V, typename K, typename V,
typename = typename std::enable_if< typename = typename std::enable_if<
std::is_base_of<ObjectRef, K>::value || std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value >::type, std::is_base_of<std::string, K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type> typename =
typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef { class Map : public ObjectRef {
public: public:
/*! /*!
* \brief default constructor * \brief default constructor
*/ */
Map() { Map() { obj_ = std::make_shared<MapObject>(); }
obj_ = std::make_shared<MapObject>();
}
/*! /*!
* \brief move constructor * \brief move constructor
* \param other source * \param other source
*/ */
Map(Map<K, V> && other) { // NOLINT(*) Map(Map<K, V>&& other) { // NOLINT(*)
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
} }
/*! /*!
* \brief copy constructor * \brief copy constructor
* \param other source * \param other source
*/ */
Map(const Map<K, V> &other) : ObjectRef(other.obj_) { // NOLINT(*) Map(const Map<K, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
} }
/*! /*!
* \brief constructor from pointer * \brief constructor from pointer
...@@ -439,7 +431,7 @@ class Map : public ObjectRef { ...@@ -439,7 +431,7 @@ class Map : public ObjectRef {
* \param end end of iterator * \param end end of iterator
* \tparam IterType The type of iterator * \tparam IterType The type of iterator
*/ */
template<typename IterType> template <typename IterType>
Map(IterType begin, IterType end) { Map(IterType begin, IterType end) {
assign(begin, end); assign(begin, end);
} }
...@@ -447,15 +439,15 @@ class Map : public ObjectRef { ...@@ -447,15 +439,15 @@ class Map : public ObjectRef {
* \brief constructor from initializer list * \brief constructor from initializer list
* \param init The initalizer list * \param init The initalizer list
*/ */
Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*) Map(std::initializer_list<std::pair<K, V> > init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
/*! /*!
* \brief constructor from vector * \brief constructor from vector
* \param init The vector * \param init The vector
*/ */
template<typename Hash, typename Equal> template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*) Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
/*! /*!
...@@ -463,7 +455,7 @@ class Map : public ObjectRef { ...@@ -463,7 +455,7 @@ class Map : public ObjectRef {
* \param other The source of assignment * \param other The source of assignment
* \return reference to self. * \return reference to self.
*/ */
Map<K, V>& operator=(Map<K, V> && other) { Map<K, V>& operator=(Map<K, V>&& other) {
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
return *this; return *this;
} }
...@@ -472,7 +464,7 @@ class Map : public ObjectRef { ...@@ -472,7 +464,7 @@ class Map : public ObjectRef {
* \param other The source of assignment * \param other The source of assignment
* \return reference to self. * \return reference to self.
*/ */
Map<K, V>& operator=(const Map<K, V> & other) { Map<K, V>& operator=(const Map<K, V>& other) {
obj_ = other.obj_; obj_ = other.obj_;
return *this; return *this;
} }
...@@ -482,12 +474,11 @@ class Map : public ObjectRef { ...@@ -482,12 +474,11 @@ class Map : public ObjectRef {
* \param end end of iterator * \param end end of iterator
* \tparam IterType The type of iterator * \tparam IterType The type of iterator
*/ */
template<typename IterType> template <typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
auto n = std::shared_ptr<MapObject>(); auto n = std::shared_ptr<MapObject>();
for (IterType i = begin; i != end; ++i) { for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first.obj_, n->data.emplace(std::make_pair(i->first.obj_, i->second.obj_));
i->second.obj_));
} }
obj_ = std::move(n); obj_ = std::move(n);
} }
...@@ -526,7 +517,7 @@ class Map : public ObjectRef { ...@@ -526,7 +517,7 @@ class Map : public ObjectRef {
* \return Handle to the internal obj container(which ganrantees to be unique) * \return Handle to the internal obj container(which ganrantees to be unique)
*/ */
inline MapObject* CopyOnWrite() { inline MapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) { if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>( obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get())); *static_cast<const MapObject*>(obj_.get()));
} }
...@@ -543,23 +534,20 @@ class Map : public ObjectRef { ...@@ -543,23 +534,20 @@ class Map : public ObjectRef {
} }
/*! \return whether list is empty */ /*! \return whether list is empty */
inline bool empty() const { inline bool empty() const { return size() == 0; }
return size() == 0;
}
/*! \brief specify container obj */ /*! \brief specify container obj */
using ContainerType = MapObject; using ContainerType = MapObject;
struct Ptr2ObjectRef { struct Ptr2ObjectRef {
using ResultType = std::pair<K, V>; using ResultType = std::pair<K, V>;
static inline ResultType convert(const std::pair< static inline ResultType convert(
std::shared_ptr<Object>, const std::pair<std::shared_ptr<Object>, std::shared_ptr<Object> >& n) {
std::shared_ptr<Object> >& n) {
return std::make_pair(K(n.first), V(n.second)); return std::make_pair(K(n.first), V(n.second));
} }
}; };
using iterator = IterAdapter< using iterator =
Ptr2ObjectRef, MapObject::ContainerType::const_iterator>; IterAdapter<Ptr2ObjectRef, MapObject::ContainerType::const_iterator>;
/*! \return begin iterator */ /*! \return begin iterator */
inline iterator begin() const { inline iterator begin() const {
...@@ -571,50 +559,49 @@ class Map : public ObjectRef { ...@@ -571,50 +559,49 @@ class Map : public ObjectRef {
} }
/*! \return begin iterator */ /*! \return begin iterator */
inline iterator find(const K& key) const { inline iterator find(const K& key) const {
return iterator(static_cast<const MapObject*>(obj_.get())->data.find(key.obj_)); return iterator(
static_cast<const MapObject*>(obj_.get())->data.find(key.obj_));
} }
}; };
// specialize of string map // specialize of string map
template<typename V, typename T1, typename T2> template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef { class Map<std::string, V, T1, T2> : public ObjectRef {
public: public:
// for code reuse // for code reuse
Map() { Map() { obj_ = std::make_shared<StrMapObject>(); }
obj_ = std::make_shared<StrMapObject>(); Map(Map<std::string, V>&& other) { // NOLINT(*)
}
Map(Map<std::string, V> && other) { // NOLINT(*)
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
} }
Map(const Map<std::string, V> &other) : ObjectRef(other.obj_) { // NOLINT(*) Map(const Map<std::string, V>& other) : ObjectRef(other.obj_) { // NOLINT(*)
} }
explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {} explicit Map(std::shared_ptr<Object> n) : ObjectRef(n) {}
template<typename IterType> template <typename IterType>
Map(IterType begin, IterType end) { Map(IterType begin, IterType end) {
assign(begin, end); assign(begin, end);
} }
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*) Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
template<typename Hash, typename Equal> template <typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*) Map(const std::unordered_map<std::string, V, Hash, Equal>&
init) { // NOLINT(*)
assign(init.begin(), init.end()); assign(init.begin(), init.end());
} }
Map<std::string, V>& operator=(Map<std::string, V> && other) { Map<std::string, V>& operator=(Map<std::string, V>&& other) {
obj_ = std::move(other.obj_); obj_ = std::move(other.obj_);
return *this; return *this;
} }
Map<std::string, V>& operator=(const Map<std::string, V> & other) { Map<std::string, V>& operator=(const Map<std::string, V>& other) {
obj_ = other.obj_; obj_ = other.obj_;
return *this; return *this;
} }
template<typename IterType> template <typename IterType>
void assign(IterType begin, IterType end) { void assign(IterType begin, IterType end) {
auto n = std::make_shared<StrMapObject>(); auto n = std::make_shared<StrMapObject>();
for (IterType i = begin; i != end; ++i) { for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, n->data.emplace(std::make_pair(i->first, i->second.obj_));
i->second.obj_));
} }
obj_ = std::move(n); obj_ = std::move(n);
} }
...@@ -633,7 +620,7 @@ class Map<std::string, V, T1, T2> : public ObjectRef { ...@@ -633,7 +620,7 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
return static_cast<const StrMapObject*>(obj_.get())->data.count(key); return static_cast<const StrMapObject*>(obj_.get())->data.count(key);
} }
inline StrMapObject* CopyOnWrite() { inline StrMapObject* CopyOnWrite() {
if (obj_.get() == nullptr || !obj_.unique()) { if (obj_.get() == nullptr || !obj_.unique()) {
obj_ = std::make_shared<MapObject>( obj_ = std::make_shared<MapObject>(
*static_cast<const MapObject*>(obj_.get())); *static_cast<const MapObject*>(obj_.get()));
} }
...@@ -643,22 +630,19 @@ class Map<std::string, V, T1, T2> : public ObjectRef { ...@@ -643,22 +630,19 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
StrMapObject* n = this->CopyOnWrite(); StrMapObject* n = this->CopyOnWrite();
n->data[key] = value.obj_; n->data[key] = value.obj_;
} }
inline bool empty() const { inline bool empty() const { return size() == 0; }
return size() == 0;
}
using ContainerType = StrMapObject; using ContainerType = StrMapObject;
struct Ptr2ObjectRef { struct Ptr2ObjectRef {
using ResultType = std::pair<std::string, V>; using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair< static inline ResultType convert(
std::string, const std::pair<std::string, std::shared_ptr<Object> >& n) {
std::shared_ptr<Object> >& n) {
return std::make_pair(n.first, V(n.second)); return std::make_pair(n.first, V(n.second));
} }
}; };
using iterator = IterAdapter< using iterator =
Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>; IterAdapter<Ptr2ObjectRef, StrMapObject::ContainerType::const_iterator>;
/*! \return begin iterator */ /*! \return begin iterator */
inline iterator begin() const { inline iterator begin() const {
...@@ -670,7 +654,8 @@ class Map<std::string, V, T1, T2> : public ObjectRef { ...@@ -670,7 +654,8 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
} }
/*! \return begin iterator */ /*! \return begin iterator */
inline iterator find(const std::string& key) const { inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapObject*>(obj_.get())->data.find(key)); return iterator(
static_cast<const StrMapObject*>(obj_.get())->data.find(key));
} }
}; };
......
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
#define DGL_RUNTIME_DEVICE_API_H_ #define DGL_RUNTIME_DEVICE_API_H_
#include <string> #include <string>
#include "packed_func.h"
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "packed_func.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -30,7 +31,8 @@ enum DeviceAttrKind : int { ...@@ -30,7 +31,8 @@ enum DeviceAttrKind : int {
/*! \brief Number of bytes each allocation must align to */ /*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64; constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */ /*! \brief Number of bytes each allocation must align to in temporary allocation
*/
constexpr int kTempAllocaAlignment = 64; constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */ /*! \brief Maximum size that can be allocated on stack */
...@@ -47,9 +49,7 @@ class DeviceAPI { ...@@ -47,9 +49,7 @@ class DeviceAPI {
/*! /*!
* \brief Check whether the device is available. * \brief Check whether the device is available.
*/ */
virtual bool IsAvailable() { virtual bool IsAvailable() { return true; }
return true;
}
/*! /*!
* \brief Set the environment device id to ctx * \brief Set the environment device id to ctx
* \param ctx The context to be set. * \param ctx The context to be set.
...@@ -62,7 +62,8 @@ class DeviceAPI { ...@@ -62,7 +62,8 @@ class DeviceAPI {
* \param rv The return value. * \param rv The return value.
* \sa DeviceAttrKind * \sa DeviceAttrKind
*/ */
virtual void GetAttr(DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0; virtual void GetAttr(
DGLContext ctx, DeviceAttrKind kind, DGLRetValue* rv) = 0;
/*! /*!
* \brief Allocate a data space on device. * \brief Allocate a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -72,10 +73,9 @@ class DeviceAPI { ...@@ -72,10 +73,9 @@ class DeviceAPI {
* as OpenGL, as nbytes & alignment are sufficient for most backends. * as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer. * \return The allocated device pointer.
*/ */
virtual void* AllocDataSpace(DGLContext ctx, virtual void* AllocDataSpace(
size_t nbytes, DGLContext ctx, size_t nbytes, size_t alignment,
size_t alignment, DGLDataType type_hint) = 0;
DGLDataType type_hint) = 0;
/*! /*!
* \brief Free a data space on device. * \brief Free a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -94,15 +94,11 @@ class DeviceAPI { ...@@ -94,15 +94,11 @@ class DeviceAPI {
* \param type_hint The type of elements, only neded by certain backends. * \param type_hint The type of elements, only neded by certain backends.
* can be useful for cross device endian converison. * can be useful for cross device endian converison.
*/ */
virtual void CopyDataFromTo(const void* from, virtual void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t num_bytes, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint) = 0;
size_t num_bytes, /*!
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) = 0;
/*!
* \brief Create a new stream of execution. * \brief Create a new stream of execution.
* *
* \param ctx The context of allocation. * \param ctx The context of allocation.
...@@ -145,31 +141,28 @@ class DeviceAPI { ...@@ -145,31 +141,28 @@ class DeviceAPI {
* \param event_src The source stream to synchronize. * \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize. * \param event_dst The destination stream to synchronize.
*/ */
DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx, DGL_DLL virtual void SyncStreamFromTo(
DGLStreamHandle event_src, DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst);
DGLStreamHandle event_dst);
/*! /*!
* \brief Pin host memory using cudaHostRegister(). * \brief Pin host memory using cudaHostRegister().
* *
* \param ptr The host memory pointer to be pinned. * \param ptr The host memory pointer to be pinned.
* \param nbytes The size to be pinned. * \param nbytes The size to be pinned.
*/ */
DGL_DLL virtual void PinData(void* ptr, size_t nbytes); DGL_DLL virtual void PinData(void* ptr, size_t nbytes);
/*! /*!
* \brief Unpin host memory using cudaHostUnregister(). * \brief Unpin host memory using cudaHostUnregister().
* *
* \param ptr The host memory pointer to be unpinned. * \param ptr The host memory pointer to be unpinned.
*/ */
DGL_DLL virtual void UnpinData(void* ptr); DGL_DLL virtual void UnpinData(void* ptr);
/*! /*!
* \brief Check whether the memory is in pinned memory. * \brief Check whether the memory is in pinned memory.
*/ */
DGL_DLL virtual bool IsPinned(const void* ptr) { DGL_DLL virtual bool IsPinned(const void* ptr) { return false; }
return false;
}
/*! /*!
* \brief Allocate temporal workspace for backend execution. * \brief Allocate temporal workspace for backend execution.
...@@ -180,16 +173,16 @@ class DeviceAPI { ...@@ -180,16 +173,16 @@ class DeviceAPI {
* - Only a few allocation will happen, and space will be released after use. * - Only a few allocation will happen, and space will be released after use.
* - The release order is usually in reverse order of allocate (stack style). * - The release order is usually in reverse order of allocate (stack style).
* - Repeative pattern of same allocations over different runs. * - Repeative pattern of same allocations over different runs.
* - Workspace should not overlap between different threads(i.e. be threadlocal) * - Workspace should not overlap between different threads(i.e. be
* threadlocal)
* *
* \param ctx The context of allocation. * \param ctx The context of allocation.
* \param nbytes The size to be allocated. * \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such * \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends. * as OpenGL, as nbytes is sufficient for most backends.
*/ */
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx, DGL_DLL virtual void* AllocWorkspace(
size_t nbytes, DGLContext ctx, size_t nbytes, DGLDataType type_hint = {});
DGLDataType type_hint = {});
/*! /*!
* \brief Free temporal workspace in backend execution. * \brief Free temporal workspace in backend execution.
* *
...@@ -206,14 +199,14 @@ class DeviceAPI { ...@@ -206,14 +199,14 @@ class DeviceAPI {
*/ */
DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false); DGL_DLL static DeviceAPI* Get(DGLContext ctx, bool allow_missing = false);
/*! /*!
* \brief Get device API based on context. * \brief Get device API based on context.
* \param dev_type The device type * \param dev_type The device type
* \param allow_missing Whether allow missing * \param allow_missing Whether allow missing
* \return The corresponding device API. * \return The corresponding device API.
*/ */
DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false); DGL_DLL static DeviceAPI* Get(
DGLDeviceType dev_type, bool allow_missing = false);
}; };
/*! \brief The device type bigger than this is RPC device */ /*! \brief The device type bigger than this is RPC device */
......
...@@ -31,10 +31,10 @@ struct DLPackConvert { ...@@ -31,10 +31,10 @@ struct DLPackConvert {
/*! /*!
* \brief Deleter for NDArray converted from DLPack. * \brief Deleter for NDArray converted from DLPack.
* *
* This is used from data which is passed from external DLPack(DLManagedTensor) * This is used from data which is passed from external
* that are not allocated inside of DGL. * DLPack(DLManagedTensor) that are not allocated inside of DGL. This enables
* This enables us to create NDArray from memory allocated by other * us to create NDArray from memory allocated by other frameworks that are
* frameworks that are DLPack compatible * DLPack compatible
*/ */
static void DLPackDeleter(NDArray::Container* ptr); static void DLPackDeleter(NDArray::Container* ptr);
...@@ -43,7 +43,7 @@ struct DLPackConvert { ...@@ -43,7 +43,7 @@ struct DLPackConvert {
* \param from The DGL NDArray. * \param from The DGL NDArray.
* \return A DLPack tensor. * \return A DLPack tensor.
*/ */
static DLManagedTensor* ToDLPack(const NDArray &from); static DLManagedTensor* ToDLPack(const NDArray& from);
}; };
} // namespace runtime } // namespace runtime
...@@ -66,8 +66,7 @@ DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor); ...@@ -66,8 +66,7 @@ DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
* \param out The output array handle. * \param out The output array handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out);
DGLArrayHandle* out);
/*! /*!
* \brief Produce a DLMangedTensor from the array that shares data memory with * \brief Produce a DLMangedTensor from the array that shares data memory with
...@@ -76,8 +75,8 @@ DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from, ...@@ -76,8 +75,8 @@ DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
* \param out The DLManagedTensor handle. * \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out, DGL_DLL int DGLArrayToDLPack(
int alignment = 0); DGLArrayHandle from, DLManagedTensor** out, int alignment = 0);
#ifdef __cplusplus #ifdef __cplusplus
} // DGL_EXTERN_C } // DGL_EXTERN_C
......
...@@ -9,10 +9,12 @@ ...@@ -9,10 +9,12 @@
#define DGL_RUNTIME_MODULE_H_ #define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h> #include <dmlc/io.h>
#include <memory> #include <memory>
#include <vector>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "c_runtime_api.h" #include "c_runtime_api.h"
namespace dgl { namespace dgl {
...@@ -29,8 +31,7 @@ class Module { ...@@ -29,8 +31,7 @@ class Module {
public: public:
Module() {} Module() {}
// constructor from container. // constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n) explicit Module(std::shared_ptr<ModuleNode> n) : node_(n) {}
: node_(n) {}
/*! /*!
* \brief Get packed function from current module by name. * \brief Get packed function from current module by name.
* *
...@@ -40,7 +41,8 @@ class Module { ...@@ -40,7 +41,8 @@ class Module {
* This function will return PackedFunc(nullptr) if function do not exist. * This function will return PackedFunc(nullptr) if function do not exist.
* \note Implemented in packed_func.cc * \note Implemented in packed_func.cc
*/ */
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false); inline PackedFunc GetFunction(
const std::string& name, bool query_imports = false);
/*! \return internal container */ /*! \return internal container */
inline ModuleNode* operator->(); inline ModuleNode* operator->();
/*! \return internal container */ /*! \return internal container */
...@@ -61,8 +63,8 @@ class Module { ...@@ -61,8 +63,8 @@ class Module {
* \note This function won't load the import relationship. * \note This function won't load the import relationship.
* Re-create import relationship by calling Import. * Re-create import relationship by calling Import.
*/ */
DGL_DLL static Module LoadFromFile(const std::string& file_name, DGL_DLL static Module LoadFromFile(
const std::string& format = ""); const std::string& file_name, const std::string& format = "");
private: private:
std::shared_ptr<ModuleNode> node_; std::shared_ptr<ModuleNode> node_;
...@@ -103,8 +105,8 @@ class ModuleNode { ...@@ -103,8 +105,8 @@ class ModuleNode {
* \param file_name The file to be saved to. * \param file_name The file to be saved to.
* \param format The format of the file. * \param format The format of the file.
*/ */
virtual void SaveToFile(const std::string& file_name, virtual void SaveToFile(
const std::string& format); const std::string& file_name, const std::string& format);
/*! /*!
* \brief Save the module to binary stream. * \brief Save the module to binary stream.
* \param stream The binary stream to save to. * \param stream The binary stream to save to.
...@@ -128,9 +130,7 @@ class ModuleNode { ...@@ -128,9 +130,7 @@ class ModuleNode {
*/ */
DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name); DGL_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */ /*! \return The module it imports from */
const std::vector<Module>& imports() const { const std::vector<Module>& imports() const { return imports_; }
return imports_;
}
protected: protected:
friend class Module; friend class Module;
...@@ -139,8 +139,7 @@ class ModuleNode { ...@@ -139,8 +139,7 @@ class ModuleNode {
private: private:
/*! \brief Cache used by GetImport */ /*! \brief Cache used by GetImport */
std::unordered_map<std::string, std::unordered_map<std::string, std::unique_ptr<PackedFunc> > import_cache_;
std::unique_ptr<PackedFunc> > import_cache_;
}; };
/*! \brief namespace for constant symbols */ /*! \brief namespace for constant symbols */
...@@ -155,20 +154,19 @@ constexpr const char* dgl_dev_mblob_nbytes = "__dgl_dev_mblob_nbytes"; ...@@ -155,20 +154,19 @@ constexpr const char* dgl_dev_mblob_nbytes = "__dgl_dev_mblob_nbytes";
constexpr const char* dgl_set_device = "__dgl_set_device"; constexpr const char* dgl_set_device = "__dgl_set_device";
/*! \brief Auxiliary counter to global barrier. */ /*! \brief Auxiliary counter to global barrier. */
constexpr const char* dgl_global_barrier_state = "__dgl_global_barrier_state"; constexpr const char* dgl_global_barrier_state = "__dgl_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that uses global barrier. */ /*!
constexpr const char* dgl_prepare_global_barrier = "__dgl_prepare_global_barrier"; * \brief Prepare the global barrier before kernels that uses global barrier.
*/
constexpr const char* dgl_prepare_global_barrier =
"__dgl_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */ /*! \brief Placeholder for the module's entry function. */
constexpr const char* dgl_module_main = "__dgl_main__"; constexpr const char* dgl_module_main = "__dgl_main__";
} // namespace symbol } // namespace symbol
// implementations of inline functions. // implementations of inline functions.
inline ModuleNode* Module::operator->() { inline ModuleNode* Module::operator->() { return node_.get(); }
return node_.get();
}
inline const ModuleNode* Module::operator->() const { inline const ModuleNode* Module::operator->() const { return node_.get(); }
return node_.get();
}
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
...@@ -7,10 +7,11 @@ ...@@ -7,10 +7,11 @@
#define DGL_RUNTIME_OBJECT_H_ #define DGL_RUNTIME_OBJECT_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <string>
#include <vector>
#include <memory> #include <memory>
#include <string>
#include <type_traits> #include <type_traits>
#include <vector>
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -26,7 +27,7 @@ class NDArray; ...@@ -26,7 +27,7 @@ class NDArray;
*/ */
class AttrVisitor { class AttrVisitor {
public: public:
//! \cond Doxygen_Suppress //! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0; virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0; virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0; virtual void Visit(const char* key, uint64_t* value) = 0;
...@@ -35,14 +36,16 @@ class AttrVisitor { ...@@ -35,14 +36,16 @@ class AttrVisitor {
virtual void Visit(const char* key, std::string* value) = 0; virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0; virtual void Visit(const char* key, ObjectRef* value) = 0;
virtual void Visit(const char* key, NDArray* value) = 0; virtual void Visit(const char* key, NDArray* value) = 0;
template<typename ENum, template <
typename = typename std::enable_if<std::is_enum<ENum>::value>::type> typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
void Visit(const char* key, ENum* ptr) { void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value, static_assert(
"declare enum to be enum int to use visitor"); std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast<int*>(ptr)); this->Visit(key, reinterpret_cast<int*>(ptr));
} }
//! \endcond //! \endcond
}; };
/*! /*!
...@@ -87,20 +90,19 @@ class Object { ...@@ -87,20 +90,19 @@ class Object {
/*! /*!
* \return whether the type is derived from * \return whether the type is derived from
*/ */
template<typename T> template <typename T>
inline bool derived_from() const; inline bool derived_from() const;
/*! /*!
* \return whether the object is of type T * \return whether the object is of type T
* \tparam The type to be checked. * \tparam The type to be checked.
*/ */
template<typename T> template <typename T>
inline bool is_type() const; inline bool is_type() const;
// object ref can see this // object ref can see this
friend class ObjectRef; friend class ObjectRef;
static constexpr const char* _type_key = "Object"; static constexpr const char* _type_key = "Object";
}; };
/*! \brief base class of all reference object */ /*! \brief base class of all reference object */
class ObjectRef { class ObjectRef {
public: public:
...@@ -109,7 +111,8 @@ class ObjectRef { ...@@ -109,7 +111,8 @@ class ObjectRef {
/*! /*!
* \brief Comparator * \brief Comparator
* *
* Compare with the two are referencing to the same object (compare by address). * Compare with the two are referencing to the same object (compare by
* address).
* *
* \param other Another object ref. * \param other Another object ref.
* \return the compare result. * \return the compare result.
...@@ -119,7 +122,8 @@ class ObjectRef { ...@@ -119,7 +122,8 @@ class ObjectRef {
/*! /*!
* \brief Comparator * \brief Comparator
* *
* Compare with the two are referencing to the same object (compare by address). * Compare with the two are referencing to the same object (compare by
* address).
* *
* \param other Another object ref. * \param other Another object ref.
* \return the compare result. * \return the compare result.
...@@ -161,8 +165,8 @@ class ObjectRef { ...@@ -161,8 +165,8 @@ class ObjectRef {
* } * }
* \tparam T the target type, must be subtype of Object * \tparam T the target type, must be subtype of Object
*/ */
template<typename T> template <typename T>
inline const T *as() const; inline const T* as() const;
/*! \brief default constructor */ /*! \brief default constructor */
ObjectRef() = default; ObjectRef() = default;
...@@ -178,11 +182,11 @@ class ObjectRef { ...@@ -178,11 +182,11 @@ class ObjectRef {
* This is macro should be used in abstract base class definition * This is macro should be used in abstract base class definition
* because it does not define type_key and type_index. * because it does not define type_key and type_index.
*/ */
#define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \ #define DGL_DECLARE_BASE_OBJECT_INFO(TypeName, Parent) \
const bool _DerivedFrom(uint32_t tid) const override { \ const bool _DerivedFrom(uint32_t tid) const override { \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
if (tidx == tid) return true; \ if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \ return Parent::_DerivedFrom(tid); \
} }
/*! /*!
...@@ -206,69 +210,61 @@ class ObjectRef { ...@@ -206,69 +210,61 @@ class ObjectRef {
* DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass); * DGL_DECLARE_OBJECT_TYPE_INFO(SomeChildClass, SomeBaseClass);
* }; * };
*/ */
#define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \ #define DGL_DECLARE_OBJECT_TYPE_INFO(TypeName, Parent) \
const char* type_key() const final { \ const char* type_key() const final { return TypeName::_type_key; } \
return TypeName::_type_key; \ uint32_t type_index() const final { \
} \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
uint32_t type_index() const final { \ return tidx; \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ } \
return tidx; \ bool _DerivedFrom(uint32_t tid) const final { \
} \ static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \
bool _DerivedFrom(uint32_t tid) const final { \ if (tidx == tid) return true; \
static uint32_t tidx = TypeKey2Index(TypeName::_type_key); \ return Parent::_DerivedFrom(tid); \
if (tidx == tid) return true; \
return Parent::_DerivedFrom(tid); \
} }
/*! \brief Macro to generate common object reference class method definition */ /*! \brief Macro to generate common object reference class method definition */
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \ #define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \ TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj): BaseTypeName(obj) {} \ explicit TypeName(std::shared_ptr<runtime::Object> obj) \
const ObjectName* operator->() const { \ : BaseTypeName(obj) {} \
return static_cast<const ObjectName*>(obj_.get()); \ const ObjectName* operator->() const { \
} \ return static_cast<const ObjectName*>(obj_.get()); \
ObjectName* operator->() { \ } \
return static_cast<ObjectName*>(obj_.get()); \ ObjectName* operator->() { return static_cast<ObjectName*>(obj_.get()); } \
} \ std::shared_ptr<ObjectName> sptr() const { \
std::shared_ptr<ObjectName> sptr() const { \ return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \ } \
} \ operator bool() const { return this->defined(); } \
operator bool() const { return this->defined(); } \
using ContainerType = ObjectName using ContainerType = ObjectName
/*! \brief Macro to generate object reference class definition */ /*! \brief Macro to generate object reference class definition */
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \ #define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \ class TypeName : public ::dgl::runtime::ObjectRef { \
public: \ public: \
DGL_DEFINE_OBJECT_REF_METHODS(TypeName, ::dgl::runtime::ObjectRef, ObjectName); \ DGL_DEFINE_OBJECT_REF_METHODS( \
TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
} }
// implementations of inline functions after this // implementations of inline functions after this
template<typename T> template <typename T>
inline bool Object::is_type() const { inline bool Object::is_type() const {
// use static field so query only happens once. // use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key); static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return type_id == this->type_index(); return type_id == this->type_index();
} }
template<typename T> template <typename T>
inline bool Object::derived_from() const { inline bool Object::derived_from() const {
// use static field so query only happens once. // use static field so query only happens once.
static uint32_t type_id = Object::TypeKey2Index(T::_type_key); static uint32_t type_id = Object::TypeKey2Index(T::_type_key);
return this->_DerivedFrom(type_id); return this->_DerivedFrom(type_id);
} }
inline const Object* ObjectRef::get() const { inline const Object* ObjectRef::get() const { return obj_.get(); }
return obj_.get();
}
inline const Object* ObjectRef::operator->() const { inline const Object* ObjectRef::operator->() const { return obj_.get(); }
return obj_.get();
}
inline bool ObjectRef::defined() const { inline bool ObjectRef::defined() const { return obj_.get() != nullptr; }
return obj_.get() != nullptr;
}
inline bool ObjectRef::operator==(const ObjectRef& other) const { inline bool ObjectRef::operator==(const ObjectRef& other) const {
return obj_.get() == other.obj_.get(); return obj_.get() == other.obj_.get();
...@@ -295,7 +291,7 @@ inline uint32_t ObjectRef::type_index() const { ...@@ -295,7 +291,7 @@ inline uint32_t ObjectRef::type_index() const {
return get()->type_index(); return get()->type_index();
} }
template<typename T> template <typename T>
inline const T* ObjectRef::as() const { inline const T* ObjectRef::as() const {
const Object* ptr = get(); const Object* ptr = get();
if (ptr && ptr->is_type<T>()) { if (ptr && ptr->is_type<T>()) {
...@@ -306,9 +302,7 @@ inline const T* ObjectRef::as() const { ...@@ -306,9 +302,7 @@ inline const T* ObjectRef::as() const {
/*! \brief The hash function for nodes */ /*! \brief The hash function for nodes */
struct ObjectHash { struct ObjectHash {
size_t operator()(const ObjectRef& a) const { size_t operator()(const ObjectRef& a) const { return a.hash(); }
return a.hash();
}
}; };
/*! \brief The equal comparator for nodes */ /*! \brief The equal comparator for nodes */
......
This diff is collapsed.
This diff is collapsed.
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "smart_ptr_serializer.h" #include "smart_ptr_serializer.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment