/*! * Copyright (c) 2018 by Contributors * @file graph/network.h * @brief DGL networking related APIs */ #ifndef DGL_GRAPH_NETWORK_H_ #define DGL_GRAPH_NETWORK_H_ #include #include #include #include #include #include "../c_api_common.h" #include "../rpc/network/msg_queue.h" using dgl::runtime::NDArray; namespace dgl { namespace network { /*! * @brief Create NDArray from raw data */ NDArray CreateNDArrayFromRaw( std::vector shape, DGLDataType dtype, DGLContext ctx, void* raw); /*! * @brief Message type for DGL distributed training */ enum MessageType { /*! * @brief Message for send/recv NodeFlow */ kNodeFlowMsg = 0, /*! * @brief Message for end-signal */ kFinalMsg = 1, /*! * @brief Initialize KVStore */ kInitMsg = 2, /*! * @brief Push msg to KVStore */ kPushMsg = 3, /*! * @brief Pull msg from KVStore */ kPullMsg = 4, /*! * @brief PullBack msg from KVStore */ kPullBackMsg = 5, /*! * @brief Barrier msg for KVStore */ kBarrierMsg = 6, /*! * @brief IP and ID msg for KVStore */ kIPIDMsg = 7, /*! * @brief Get data shape msg for KVStore */ kGetShapeMsg = 8, /*! * @brief Get data shape back msg for KVStore */ kGetShapeBackMsg = 9 }; /*! * @brief Meta data for NDArray message */ class ArrayMeta { public: /*! * @brief ArrayMeta constructor. * @param msg_type type of message */ explicit ArrayMeta(int msg_type) : msg_type_(msg_type), ndarray_count_(0) {} /*! * @brief Construct ArrayMeta from binary data buffer. * @param buffer data buffer * @param size data size */ ArrayMeta(char* buffer, int64_t size) { CHECK_NOTNULL(buffer); this->Deserialize(buffer, size); } /*! * @return message type */ inline int msg_type() const { return msg_type_; } /*! * @return count of ndarray */ inline int ndarray_count() const { return ndarray_count_; } /*! * @brief Add NDArray meta data to ArrayMeta * @param array DGL NDArray */ void AddArray(const NDArray& array); /*! * @brief Serialize ArrayMeta to data buffer * @param size size of serialized message * @return pointer of data buffer */ char* Serialize(int64_t* size); /*! * @brief Deserialize ArrayMeta from data buffer * @param buffer data buffer * @param size size of data buffer */ void Deserialize(char* buffer, int64_t size); /*! * @brief type of message */ int msg_type_; /*! * @brief count of ndarray in MetaMsg */ int ndarray_count_; /*! * @brief DataType for each NDArray */ std::vector data_type_; /*! * @brief We first write the ndim to data_shape_ * and then write the data shape. */ std::vector data_shape_; }; /*! * @brief C structure for holding DGL KVServer message */ class KVStoreMsg { public: /*! * @brief KVStoreMsg constructor. */ KVStoreMsg() {} /*! * @brief Construct KVStoreMsg from binary data buffer. * @param buffer data buffer * @param size data size */ KVStoreMsg(char* buffer, int64_t size) { CHECK_NOTNULL(buffer); this->Deserialize(buffer, size); } /*! * @brief Serialize KVStoreMsg to data buffer * Note that we don't serialize ID and data here. * @param size size of serialized message * @return pointer of data buffer */ char* Serialize(int64_t* size); /*! * @brief Deserialize KVStoreMsg from data buffer * @param buffer data buffer * @param size size of data buffer */ void Deserialize(char* buffer, int64_t size); /*! * @brief Message type of kvstore */ int msg_type; /*! * @brief Sender's ID */ int rank; /*! * @brief data name */ std::string name; /*! * @brief data ID */ NDArray id; /*! * @brief data matrix */ NDArray data; /*! * @brief data shape */ NDArray shape; }; } // namespace network } // namespace dgl #endif // DGL_GRAPH_NETWORK_H_