Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
...@@ -15,9 +15,9 @@ namespace featgraph { ...@@ -15,9 +15,9 @@ namespace featgraph {
void LoadFeatGraphModule(const std::string& path); void LoadFeatGraphModule(const std::string& path);
/* \brief Call Featgraph's SDDMM kernel. */ /* \brief Call Featgraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col, void SDDMMTreeReduction(
DLManagedTensor* lhs, DLManagedTensor* rhs, DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* out); DLManagedTensor* rhs, DLManagedTensor* out);
} // namespace featgraph } // namespace featgraph
} // namespace dgl } // namespace dgl
......
...@@ -3,18 +3,18 @@ ...@@ -3,18 +3,18 @@
* \file featgraph/src/featgraph.cc * \file featgraph/src/featgraph.cc
* \brief FeatGraph kernels. * \brief FeatGraph kernels.
*/ */
#include <dmlc/logging.h>
#include <featgraph.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <featgraph.h>
namespace dgl { namespace dgl {
namespace featgraph { namespace featgraph {
/* \brief Singleton that loads the featgraph module. */ /* \brief Singleton that loads the featgraph module. */
class FeatGraphModule { class FeatGraphModule {
public: public:
static FeatGraphModule* Global() { static FeatGraphModule* Global() {
static FeatGraphModule inst; static FeatGraphModule inst;
return &inst; return &inst;
...@@ -32,7 +32,8 @@ public: ...@@ -32,7 +32,8 @@ public:
} }
return ret; return ret;
} }
private:
private:
tvm::runtime::Module mod; tvm::runtime::Module mod;
FeatGraphModule() {} FeatGraphModule() {}
}; };
...@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) { ...@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) {
/* \brief Convert DLDataType to string. */ /* \brief Convert DLDataType to string. */
inline std::string DTypeAsStr(const DLDataType& t) { inline std::string DTypeAsStr(const DLDataType& t) {
switch(t.code) { switch (t.code) {
case 0U: return "int" + std::to_string(t.bits); case 0U:
case 1U: return "uint" + std::to_string(t.bits); return "int" + std::to_string(t.bits);
case 2U: return "float" + std::to_string(t.bits); case 1U:
case 3U: return "bfloat" + std::to_string(t.bits); return "uint" + std::to_string(t.bits);
default: LOG(FATAL) << "Type code " << t.code << " not recognized"; case 2U:
return "float" + std::to_string(t.bits);
case 3U:
return "bfloat" + std::to_string(t.bits);
default:
LOG(FATAL) << "Type code " << t.code << " not recognized";
} }
} }
/* \brief Get operator filename. */ /* \brief Get operator filename. */
inline std::string GetOperatorName( inline std::string GetOperatorName(
const std::string& base_name, const std::string& base_name, const DLDataType& dtype,
const DLDataType& dtype,
const DLDataType& idtype) { const DLDataType& idtype) {
return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype); return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
} }
/* \brief Call FeatGraph's SDDMM kernel. */ /* \brief Call FeatGraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col, void SDDMMTreeReduction(
DLManagedTensor* lhs, DLManagedTensor* rhs, DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* out) { DLManagedTensor* rhs, DLManagedTensor* out) {
tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get(); tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
std::string f_name = GetOperatorName("SDDMMTreeReduction", std::string f_name = GetOperatorName(
(row->dl_tensor).dtype, "SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
(lhs->dl_tensor).dtype);
tvm::runtime::PackedFunc f = mod->GetFunction(f_name); tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
if (f != nullptr) if (f != nullptr) f(row, col, lhs, rhs, out);
f(row, col, lhs, rhs, out);
} }
} // namespace featgraph } // namespace featgraph
......
/* /*
* NOTE(zihao): this file was modified from TVM project: * NOTE(zihao): this file was modified from TVM project:
* - https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc * -
* * https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
*
* Licensed to the Apache Software Foundation (ASF) under one * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file * or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information * distributed with this work for additional information
......
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
*/ */
#ifndef DGL_ARRAY_H_ #ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_ #define DGL_ARRAY_H_
#include "./aten/types.h"
#include "./aten/array_ops.h" #include "./aten/array_ops.h"
#include "./aten/coo.h"
#include "./aten/csr.h"
#include "./aten/macro.h" #include "./aten/macro.h"
#include "./aten/spmat.h" #include "./aten/spmat.h"
#include "./aten/csr.h" #include "./aten/types.h"
#include "./aten/coo.h"
#endif // DGL_ARRAY_H_ #endif // DGL_ARRAY_H_
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#define CUB_INLINE inline #define CUB_INLINE inline
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
#include <iterator>
#include <algorithm> #include <algorithm>
#include <iterator>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) { ...@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) {
r1.swap(r2); r1.swap(r2);
} }
// PairRef and PairIterator that serves as an iterator over a pair of arrays in a // PairRef and PairIterator that serves as an iterator over a pair of arrays in
// zipped fashion like zip(a, b). // a zipped fashion like zip(a, b).
template <typename DType> template <typename DType>
struct PairRef { struct PairRef {
PairRef() = delete; PairRef() = delete;
...@@ -69,9 +69,7 @@ struct PairRef { ...@@ -69,9 +69,7 @@ struct PairRef {
*b = val.second; *b = val.second;
return *this; return *this;
} }
CUB_INLINE operator Pair<DType>() const { CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }
return Pair<DType>(*a, *b);
}
CUB_INLINE operator std::pair<DType, DType>() const { CUB_INLINE operator std::pair<DType, DType>() const {
return std::make_pair(*a, *b); return std::make_pair(*a, *b);
} }
...@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) { ...@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {
} }
template <typename DType> template <typename DType>
struct PairIterator : public std::iterator<std::random_access_iterator_tag, struct PairIterator : public std::iterator<
Pair<DType>, std::random_access_iterator_tag, Pair<DType>,
std::ptrdiff_t, std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {
Pair<DType*>,
PairRef<DType>> {
PairIterator() = default; PairIterator() = default;
PairIterator(const PairIterator& other) = default; PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default; PairIterator(PairIterator&& other) = default;
...@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
PairIterator& operator=(const PairIterator& other) = default; PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default; PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default; ~PairIterator() = default;
CUB_INLINE bool operator==(const PairIterator& other) const { return a == other.a; } CUB_INLINE bool operator==(const PairIterator& other) const {
CUB_INLINE bool operator!=(const PairIterator& other) const { return a != other.a; } return a == other.a;
CUB_INLINE bool operator<(const PairIterator& other) const { return a < other.a; } }
CUB_INLINE bool operator>(const PairIterator& other) const { return a > other.a; } CUB_INLINE bool operator!=(const PairIterator& other) const {
CUB_INLINE bool operator<=(const PairIterator& other) const { return a <= other.a; } return a != other.a;
CUB_INLINE bool operator>=(const PairIterator& other) const { return a >= other.a; } }
CUB_INLINE bool operator<(const PairIterator& other) const {
return a < other.a;
}
CUB_INLINE bool operator>(const PairIterator& other) const {
return a > other.a;
}
CUB_INLINE bool operator<=(const PairIterator& other) const {
return a <= other.a;
}
CUB_INLINE bool operator>=(const PairIterator& other) const {
return a >= other.a;
}
CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) { CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {
a += movement; a += movement;
b += movement; b += movement;
...@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const { CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {
return a - other.a; return a - other.a;
} }
CUB_INLINE PairRef<DType> operator*() const { CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }
return PairRef<DType>(a, b); CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }
}
CUB_INLINE PairRef<DType> operator*() {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator[](size_t offset) const { CUB_INLINE PairRef<DType> operator[](size_t offset) const {
return PairRef<DType>(a + offset, b + offset); return PairRef<DType>(a + offset, b + offset);
} }
......
...@@ -8,8 +8,9 @@ ...@@ -8,8 +8,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./types.h"
#include "../runtime/object.h" #include "../runtime/object.h"
#include "./types.h"
namespace dgl { namespace dgl {
...@@ -56,31 +57,28 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) { ...@@ -56,31 +57,28 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) { inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret; std::vector<SparseFormat> ret;
if (code & COO_CODE) if (code & COO_CODE) ret.push_back(SparseFormat::kCOO);
ret.push_back(SparseFormat::kCOO); if (code & CSR_CODE) ret.push_back(SparseFormat::kCSR);
if (code & CSR_CODE) if (code & CSC_CODE) ret.push_back(SparseFormat::kCSC);
ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC);
return ret; return ret;
} }
inline dgl_format_code_t inline dgl_format_code_t SparseFormatsToCode(
SparseFormatsToCode(const std::vector<SparseFormat> &formats) { const std::vector<SparseFormat>& formats) {
dgl_format_code_t ret = 0; dgl_format_code_t ret = 0;
for (auto format : formats) { for (auto format : formats) {
switch (format) { switch (format) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
ret |= COO_CODE; ret |= COO_CODE;
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
ret |= CSR_CODE; ret |= CSR_CODE;
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
ret |= CSC_CODE; ret |= CSC_CODE;
break; break;
default: default:
LOG(FATAL) << "Only support COO/CSR/CSC formats."; LOG(FATAL) << "Only support COO/CSR/CSC formats.";
} }
} }
return ret; return ret;
...@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) { ...@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline std::string CodeToStr(dgl_format_code_t code) { inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = ""; std::string ret = "";
if (code & COO_CODE) if (code & COO_CODE) ret += "coo ";
ret += "coo "; if (code & CSR_CODE) ret += "csr ";
if (code & CSR_CODE) if (code & CSC_CODE) ret += "csc ";
ret += "csr ";
if (code & CSC_CODE)
ret += "csc ";
return ret; return ret;
} }
inline SparseFormat DecodeFormat(dgl_format_code_t code) { inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & COO_CODE) if (code & COO_CODE) return SparseFormat::kCOO;
return SparseFormat::kCOO; if (code & CSC_CODE) return SparseFormat::kCSC;
if (code & CSC_CODE)
return SparseFormat::kCSC;
return SparseFormat::kCSR; return SparseFormat::kCSR;
} }
...@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object { ...@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object {
// Shape of this matrix. // Shape of this matrix.
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}. // Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,
// col, data}.
std::vector<IdArray> indices; std::vector<IdArray> indices;
// Boolean flags. // Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently, // TODO(minjie): We might revisit this later to provide a more general
// we only consider aten::COOMatrix and aten::CSRMatrix. // solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags; std::vector<bool> flags;
SparseMatrix() {} SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols, SparseMatrix(
const std::vector<IdArray>& idx, int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<bool>& flg) const std::vector<IdArray>& idx, const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {} : format(fmt),
num_rows(nrows),
num_cols(ncols),
indices(idx),
flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix"; static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/types.h * \file dgl/aten/types.h
* \brief Array and ID types * \brief Array and ID types
*/ */
#ifndef DGL_ATEN_TYPES_H_ #ifndef DGL_ATEN_TYPES_H_
#define DGL_ATEN_TYPES_H_ #define DGL_ATEN_TYPES_H_
#include <cstdint> #include <cstdint>
#include "../runtime/ndarray.h" #include "../runtime/ndarray.h"
namespace dgl { namespace dgl {
...@@ -15,7 +16,7 @@ typedef uint64_t dgl_id_t; ...@@ -15,7 +16,7 @@ typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t; typedef uint64_t dgl_type_t;
/*! \brief Type for dgl fomrat code, whose binary representation indices /*! \brief Type for dgl fomrat code, whose binary representation indices
* which sparse format is in use and which is not. * which sparse format is in use and which is not.
* *
* Suppose the binary representation is xyz, then * Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false). * - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use. * - y indicates whether csr is in use.
......
This diff is collapsed.
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./runtime/ndarray.h" #include "./runtime/ndarray.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -37,7 +38,7 @@ struct BcastOff { ...@@ -37,7 +38,7 @@ struct BcastOff {
/*! \brief Whether broadcast is required or not. */ /*! \brief Whether broadcast is required or not. */
bool use_bcast; bool use_bcast;
/*! /*!
* \brief Auxiliary information for kernel computation * \brief Auxiliary information for kernel computation
* \note lhs_len refers to the left hand side operand length. * \note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5) * e.g. 15 for shape (1, 3, 5)
* rhs_len refers to the right hand side operand length. * rhs_len refers to the right hand side operand length.
...@@ -61,7 +62,6 @@ struct BcastOff { ...@@ -61,7 +62,6 @@ struct BcastOff {
*/ */
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs); BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl } // namespace dgl
#endif // DGL_BCAST_H_ #endif // DGL_BCAST_H_
...@@ -6,13 +6,12 @@ ...@@ -6,13 +6,12 @@
#ifndef DGL_GRAPH_H_ #ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_ #define DGL_GRAPH_H_
#include <string>
#include <vector>
#include <string>
#include <cstdint> #include <cstdint>
#include <utility>
#include <tuple>
#include <memory> #include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "graph_interface.h" #include "graph_interface.h"
...@@ -23,7 +22,7 @@ class GraphOp; ...@@ -23,7 +22,7 @@ class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr; typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! \brief Mutable graph based on adjacency list. */ /*! \brief Mutable graph based on adjacency list. */
class Graph: public GraphInterface { class Graph : public GraphInterface {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Graph() {} Graph() {}
...@@ -89,13 +88,9 @@ class Graph: public GraphInterface { ...@@ -89,13 +88,9 @@ class Graph: public GraphInterface {
num_edges_ = 0; num_edges_ = 0;
} }
DGLContext Context() const override { DGLContext Context() const override { return DGLContext{kDGLCPU, 0}; }
return DGLContext{kDGLCPU, 0};
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return 64; }
return 64;
}
/*! /*!
* \note not const since we have caches * \note not const since we have caches
...@@ -106,21 +101,17 @@ class Graph: public GraphInterface { ...@@ -106,21 +101,17 @@ class Graph: public GraphInterface {
/*! /*!
* \return whether the graph is read-only * \return whether the graph is read-only
*/ */
bool IsReadonly() const override { bool IsReadonly() const override { return false; }
return false;
}
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override { uint64_t NumVertices() const override { return adjlist_.size(); }
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/ /*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override { uint64_t NumEdges() const override { return num_edges_; }
return num_edges_;
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
...@@ -132,7 +123,8 @@ class Graph: public GraphInterface { ...@@ -132,7 +123,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;
...@@ -140,7 +132,8 @@ class Graph: public GraphInterface { ...@@ -140,7 +132,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
...@@ -169,7 +162,8 @@ class Graph: public GraphInterface { ...@@ -169,7 +162,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]); return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);
...@@ -178,7 +172,8 @@ class Graph: public GraphInterface { ...@@ -178,7 +172,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
EdgeArray FindEdges(IdArray eids) const override; EdgeArray FindEdges(IdArray eids) const override;
...@@ -216,10 +211,11 @@ class Graph: public GraphInterface { ...@@ -216,10 +211,11 @@ class Graph: public GraphInterface {
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and * \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order. * dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids * \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray Edges(const std::string &order = "") const override; EdgeArray Edges(const std::string& order = "") const override;
/*! /*!
* \brief Get the in degree of the given vertex. * \brief Get the in degree of the given vertex.
...@@ -258,13 +254,14 @@ class Graph: public GraphInterface { ...@@ -258,13 +254,14 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then * The induced subgraph is a subgraph formed by specifying a set of vertices
* selecting all of the edges from the original graph that connect two vertices in V'. * V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the vertices preserve the order of the given id array, while the local index * index. The local index of the vertices preserve the order of the given id
* of the edges preserve the index order in the original graph. Vertices not in the * array, while the local index of the edges preserve the index order in the
* original graph are ignored. * original graph. Vertices not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
...@@ -276,20 +273,22 @@ class Graph: public GraphInterface { ...@@ -276,20 +273,22 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then * The induced edges subgraph is a subgraph formed by specifying a set of
* selecting all of the nodes from the original graph that are endpoints in E'. * edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the edges preserve the order of the given id array, while the local index * index. The local index of the edges preserve the order of the given id
* of the vertices preserve the index order in the original graph. Edges not in the * array, while the local index of the vertices preserve the index order in
* original graph are ignored. * the original graph. Edges not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
...@@ -344,12 +343,11 @@ class Graph: public GraphInterface { ...@@ -344,12 +343,11 @@ class Graph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray. * \return a vector of three IdArray.
*/ */
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override; std::vector<IdArray> GetAdj(
bool transpose, const std::string& fmt) const override;
/*! \brief Create an empty graph */ /*! \brief Create an empty graph */
static MutableGraphPtr Create() { static MutableGraphPtr Create() { return std::make_shared<Graph>(); }
return std::make_shared<Graph>();
}
/*! \brief Create from coo */ /*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO( static MutableGraphPtr CreateFromCOO(
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
#ifndef DGL_GRAPH_INTERFACE_H_ #ifndef DGL_GRAPH_INTERFACE_H_
#define DGL_GRAPH_INTERFACE_H_ #define DGL_GRAPH_INTERFACE_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h" #include "./runtime/object.h"
#include "array.h" #include "array.h"
...@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1); ...@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
* \brief This class references data in std::vector. * \brief This class references data in std::vector.
* *
* This isn't a STL-style iterator. It provides a STL data container interface. * This isn't a STL-style iterator. It provides a STL data container interface.
* but it doesn't own data itself. instead, it only references data in std::vector. * but it doesn't own data itself. instead, it only references data in
* std::vector.
*/ */
class DGLIdIters { class DGLIdIters {
public: public:
...@@ -34,18 +35,11 @@ class DGLIdIters { ...@@ -34,18 +35,11 @@ class DGLIdIters {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
} }
const dgl_id_t *begin() const { const dgl_id_t *begin() const { return this->begin_; }
return this->begin_; const dgl_id_t *end() const { return this->end_; }
} dgl_id_t operator[](int64_t i) const { return *(this->begin_ + i); }
const dgl_id_t *end() const { size_t size() const { return this->end_ - this->begin_; }
return this->end_;
}
dgl_id_t operator[](int64_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
private: private:
const dgl_id_t *begin_{nullptr}, *end_{nullptr}; const dgl_id_t *begin_{nullptr}, *end_{nullptr};
}; };
...@@ -63,23 +57,15 @@ class DGLIdIters32 { ...@@ -63,23 +57,15 @@ class DGLIdIters32 {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
} }
const int32_t *begin() const { const int32_t *begin() const { return this->begin_; }
return this->begin_; const int32_t *end() const { return this->end_; }
} int32_t operator[](int32_t i) const { return *(this->begin_ + i); }
const int32_t *end() const { size_t size() const { return this->end_ - this->begin_; }
return this->end_;
}
int32_t operator[](int32_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
private: private:
const int32_t *begin_{nullptr}, *end_{nullptr}; const int32_t *begin_{nullptr}, *end_{nullptr};
}; };
/* \brief structure used to represent a list of edges */ /* \brief structure used to represent a list of edges */
typedef struct { typedef struct {
/* \brief the two endpoints and the id of the edge */ /* \brief the two endpoints and the id of the edge */
...@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object { ...@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object {
virtual DGLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64). * \brief Get the number of integer bits used to store node/edge ids
* (32 or 64).
*/ */
virtual uint8_t NumBits() const = 0; virtual uint8_t NumBits() const = 0;
...@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object { ...@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object {
virtual uint64_t NumEdges() const = 0; virtual uint64_t NumEdges() const = 0;
/*! \return true if the given vertex is in the graph.*/ /*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const { virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
return vid < NumVertices();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(IdArray vids) const = 0; virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
...@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object { ...@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;
...@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object { ...@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;
...@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object { ...@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0; virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
virtual EdgeArray FindEdges(IdArray eids) const = 0; virtual EdgeArray FindEdges(IdArray eids) const = 0;
...@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object { ...@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and * \note If order is "srcdst", the returned edges list is sorted by their src
* dst ids. If order is "eid", they are in their edge id order. * and dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order. * Otherwise, in the arbitrary order.
* \param order The order of the returned edge list. * \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
...@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object { ...@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then * The induced subgraph is a subgraph formed by specifying a set of vertices
* selecting all of the edges from the original graph that connect two vertices in V'. * V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the vertices preserve the order of the given id array, while the local index * index. The local index of the vertices preserve the order of the given id
* of the edges preserve the index order in the original graph. Vertices not in the * array, while the local index of the edges preserve the index order in the
* original graph are ignored. * original graph. Vertices not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
...@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object { ...@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then * The induced edges subgraph is a subgraph formed by specifying a set of
* selecting all of the nodes from the original graph that are endpoints in E'. * edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the edges preserve the order of the given id array, while the local index * index. The local index of the edges preserve the order of the given id
* of the vertices preserve the index order in the original graph. Edges not in the * array, while the local index of the vertices preserve the index order in
* original graph are ignored. * the original graph. Edges not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices * \param preserve_nodes If true, the vertices will not be relabeled, so some
* may have no incident edges. * vertices may have no incident edges.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0; virtual Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const = 0;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
...@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object { ...@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing * If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids * indptr, indices and edge ids
* *
* If the fmt is 'coo', the function should return one array of shape (2, nnz), * If the fmt is 'coo', the function should return one array of shape (2,
* representing a horitonzal stack of row and col indices. * nnz), representing a horitonzal stack of row and col indices.
* *
* \param transpose A flag to transpose the returned adjacency matrix. * \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0; virtual std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const = 0;
/*! /*!
* \brief Sort the columns in CSR. * \brief Sort the columns in CSR.
...@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object { ...@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object {
* This sorts the columns in each row based on the column Ids. * This sorts the columns in each row based on the column Ids.
* The edge ids should be sorted accordingly. * The edge ids should be sorted accordingly.
*/ */
virtual void SortCSR() { virtual void SortCSR() {}
}
static constexpr const char* _type_key = "graph.Graph"; static constexpr const char *_type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
}; };
...@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object { ...@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object {
GraphPtr graph; GraphPtr graph;
/*! /*!
* \brief The induced vertex ids. * \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph. * \note This is also a map from the new vertex id to the vertex id in the
* parent graph.
*/ */
IdArray induced_vertices; IdArray induced_vertices;
/*! /*!
* \brief The induced edge ids. * \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph. * \note This is also a map from the new edge id to the edge id in the parent
* graph.
*/ */
IdArray induced_edges; IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph"; static constexpr const char *_type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
}; };
/*! \brief Subgraph data structure for negative subgraph */ /*! \brief Subgraph data structure for negative subgraph */
struct NegSubgraph: public Subgraph { struct NegSubgraph : public Subgraph {
/*! \brief The existence of the negative edges in the parent graph. */ /*! \brief The existence of the negative edges in the parent graph. */
IdArray exist; IdArray exist;
...@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph { ...@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph {
}; };
/*! \brief Subgraph data structure for halo subgraph */ /*! \brief Subgraph data structure for halo subgraph */
struct HaloSubgraph: public Subgraph { struct HaloSubgraph : public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */ /*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes; IdArray inner_nodes;
}; };
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_GRAPH_OP_H_ #define DGL_GRAPH_OP_H_
#include <vector> #include <vector>
#include "graph.h" #include "graph.h"
#include "immutable_graph.h" #include "immutable_graph.h"
...@@ -17,7 +18,8 @@ class GraphOp { ...@@ -17,7 +18,8 @@ class GraphOp {
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
* *
* The returned graph preserves the vertex and edge index in the original graph. * The returned graph preserves the vertex and edge index in the original
* graph.
* *
* \return the reversed graph * \return the reversed graph
*/ */
...@@ -40,10 +42,10 @@ class GraphOp { ...@@ -40,10 +42,10 @@ class GraphOp {
* \brief Return a disjoint union of the input graphs. * \brief Return a disjoint union of the input graphs.
* *
* The new graph will include all the nodes/edges in the given graphs. * The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes * Nodes/Edges will be relabled by adding the cumsum of the previous graph
* in the given sequence order. For example, giving input [g1, g2, g3], where * sizes in the given sequence order. For example, giving input [g1, g2, g3],
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7 * where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become
* in the result graph. Edge ids are re-assigned similarly. * node#7 in the result graph. Edge ids are re-assigned similarly.
* *
* The input list must be either ALL mutable graphs or ALL immutable graphs. * The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type. * The returned graph type is also determined by the input graph type.
...@@ -62,12 +64,13 @@ class GraphOp { ...@@ -62,12 +64,13 @@ class GraphOp {
* *
* If the input graph is mutable, the result graphs are mutable. * If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable. * If the input graph is immutable, the result graphs are immutable.
* *
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param num The number of partitions. * \param num The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<GraphPtr> DisjointPartitionByNum(GraphPtr graph, int64_t num); static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num);
/*! /*!
* \brief Partition the graph into several subgraphs. * \brief Partition the graph into several subgraphs.
...@@ -78,21 +81,22 @@ class GraphOp { ...@@ -78,21 +81,22 @@ class GraphOp {
* *
* If the input graph is mutable, the result graphs are mutable. * If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable. * If the input graph is immutable, the result graphs are immutable.
* *
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param sizes The number of partitions. * \param sizes The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<GraphPtr> DisjointPartitionBySizes(GraphPtr graph, IdArray sizes); static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes);
/*! /*!
* \brief Map vids in the parent graph to the vids in the subgraph. * \brief Map vids in the parent graph to the vids in the subgraph.
* *
* If the Id doesn't exist in the subgraph, -1 will be used. * If the Id doesn't exist in the subgraph, -1 will be used.
* *
* \param parent_vid_map An array that maps the vids in the parent graph to the * \param parent_vid_map An array that maps the vids in the parent graph to
* subgraph. The elements store the vertex Ids in the parent graph, and the * the subgraph. The elements store the vertex Ids in the parent graph, and
* indices indicate the vertex Ids in the subgraph. * the indices indicate the vertex Ids in the subgraph.
* \param query The vertex Ids in the parent graph. * \param query The vertex Ids in the parent graph.
* \return an Id array that contains the subgraph node Ids. * \return an Id array that contains the subgraph node Ids.
*/ */
...@@ -134,15 +138,18 @@ class GraphOp { ...@@ -134,15 +138,18 @@ class GraphOp {
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph); static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*! /*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable. * \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable.
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable bidirected graph. * \return a new immutable bidirected
* graph.
*/ */
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph); static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/*! /*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable * \brief Same as BidirectedMutableGraph except that the returned graph is
* and call gk_csr_MakeSymmetric in GKlib. This is more efficient than ToBidirectedImmutableGraph. * immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* It return a null pointer if the conversion fails. * than ToBidirectedImmutableGraph. It return a null pointer if the conversion
* fails.
* *
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable bidirected graph. * \return a new immutable bidirected graph.
...@@ -151,21 +158,25 @@ class GraphOp { ...@@ -151,21 +158,25 @@ class GraphOp {
/*! /*!
* \brief Get a induced subgraph with HALO nodes. * \brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within `num_hops`. * The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`.
* \param graph The input graph. * \param graph The input graph.
* \param nodes The input nodes that form the core of the induced subgraph. * \param nodes The input nodes that form the core of the induced subgraph.
* \param num_hops The number of hops to reach. * \param num_hops The number of hops to reach.
* \return the induced subgraph with HALO nodes. * \return the induced subgraph with HALO nodes.
*/ */
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops); static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops);
/*! /*!
* \brief Reorder the nodes in the immutable graph. * \brief Reorder the nodes in the immutable graph.
* \param graph The input graph. * \param graph The input graph.
* \param new_order The node Ids in the new graph. The index in `new_order` is old node Ids. * \param new_order The node Ids in the new graph. The index in `new_order` is
* old node Ids.
* \return the graph with reordered node Ids * \return the graph with reordered node Ids
*/ */
static GraphPtr ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order); static GraphPtr ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order);
}; };
} // namespace dgl } // namespace dgl
......
...@@ -16,7 +16,8 @@ namespace dgl { ...@@ -16,7 +16,8 @@ namespace dgl {
* \brief Class for representing frontiers. * \brief Class for representing frontiers.
* *
* Each frontier is a list of nodes/edges (specified by their ids). * Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value). * An optional tag can be specified on each node/edge (represented by an int
* value).
*/ */
struct Frontiers { struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */ /*!\brief a vector store for the nodes/edges in all the frontiers */
...@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); ...@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* FORWARD(0), REVERSE(1), NONTREE(2) * FORWARD(0), REVERSE(1), NONTREE(2)
* *
* A FORWARD edge is one in which `u` has been visisted but `v` has not. * A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge * A REVERSE edge is one in which both `u` and `v` have been visisted and the
* is in the DFS tree. * edge is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge * A NONTREE edge is one in which both `u` and `v` have been visisted but the
* is NOT in the DFS tree. * edge is NOT in the DFS tree.
* *
* \param csr The input csr matrix. * \param csr The input csr matrix.
* \param sources Source nodes. * \param sources Source nodes.
...@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); ...@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* \param return_labels If true, return the recorded edge tags. * \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result * \return A Frontiers object containing the search result
*/ */
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(
IdArray source, const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_reverse_edge, const bool has_nontree_edge, const bool return_labels);
const bool has_nontree_edge,
const bool return_labels);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
This diff is collapsed.
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#define DGL_KERNEL_H_ #define DGL_KERNEL_H_
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "array.h"
#include "./bcast.h"
#include "./base_heterograph.h" #include "./base_heterograph.h"
#include "./bcast.h"
#include "array.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -30,12 +30,9 @@ namespace aten { ...@@ -30,12 +30,9 @@ namespace aten {
* as the argmax on source nodes and edges for reduce operators such as * as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`. * `min` and `max`.
*/ */
void SpMM(const std::string& op, const std::string& reduce, void SpMM(
HeteroGraphPtr graph, const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
/*! /*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication. * \brief Generalized Sampled Dense-Dense Matrix Multiplication.
...@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce, ...@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce,
* \param vfeat The destination node feature. * \param vfeat The destination node feature.
* \param out The output feature on edge. * \param out The output feature on edge.
*/ */
void SDDMM(const std::string& op, void SDDMM(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray ufeat, NDArray out);
NDArray efeat,
NDArray out);
/*! /*!
* \brief Sparse-sparse matrix multiplication. * \brief Sparse-sparse matrix multiplication.
* *
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a B_weights * The sparse matrices must have scalar weights (i.e. \a A_weights and \a
* are 1D vectors.) * B_weights are 1D vectors.)
*/ */
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A, CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);
/*! /*!
* \brief Summing up a list of sparse matrices. * \brief Summing up a list of sparse matrices.
...@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM(
* are 1D vectors.) * are 1D vectors.)
*/ */
std::pair<CSRMatrix, NDArray> CSRSum( std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A, const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
const std::vector<NDArray>& A_weights);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -22,15 +22,17 @@ class Lazy { ...@@ -22,15 +22,17 @@ class Lazy {
/*!\brief default constructor to construct a lazy object */ /*!\brief default constructor to construct a lazy object */
Lazy() {} Lazy() {}
/*!\brief constructor to construct an object with given value (non-lazy case) */ /*!
explicit Lazy(const T& val): ptr_(new T(val)) {} * \brief constructor to construct an object with given value (non-lazy case)
*/
explicit Lazy(const T& val) : ptr_(new T(val)) {}
/*!\brief destructor */ /*!\brief destructor */
~Lazy() = default; ~Lazy() = default;
/*! /*!
* \brief Get the value of this object. If the object has not been instantiated, * \brief Get the value of this object. If the object has not been
* using the provided function to create it. * instantiated, using the provided function to create it.
* \param fn The creator function. * \param fn The creator function.
* \return the object value. * \return the object value.
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#ifndef DGL_NODEFLOW_H_ #ifndef DGL_NODEFLOW_H_
#define DGL_NODEFLOW_H_ #define DGL_NODEFLOW_H_
#include <vector>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "./runtime/object.h" #include "./runtime/object.h"
#include "graph_interface.h" #include "graph_interface.h"
...@@ -18,12 +18,12 @@ namespace dgl { ...@@ -18,12 +18,12 @@ namespace dgl {
class ImmutableGraph; class ImmutableGraph;
/*! /*!
* \brief A NodeFlow graph stores the sampling results for a sampler that samples * \brief A NodeFlow graph stores the sampling results for a sampler that
* nodes/edges in layers. * samples nodes/edges in layers.
* *
* We store multiple layers of the sampling results in a single graph, which results * We store multiple layers of the sampling results in a single graph, which
* in a more compact format. We store extra information, * results in a more compact format. We store extra information, such as the
* such as the node and edge mapping from the NodeFlow graph to the parent graph. * node and edge mapping from the NodeFlow graph to the parent graph.
*/ */
struct NodeFlowObject : public runtime::Object { struct NodeFlowObject : public runtime::Object {
/*! \brief The graph. */ /*! \brief The graph. */
...@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object { ...@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object {
*/ */
IdArray edge_mapping; IdArray edge_mapping;
static constexpr const char* _type_key = "graph.NodeFlow"; static constexpr const char *_type_key = "graph.NodeFlow";
DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);
}; };
...@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef { ...@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef {
* of an edge and the column represents the source. * of an edge and the column represents the source.
* *
* If fmt == "csr", the function returns three arrays: indptr, indices, eid. * If fmt == "csr", the function returns three arrays: indptr, indices, eid.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx array * If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx
* is the concatenation of src and dst node id arrays. * array is the concatenation of src and dst node id arrays.
* *
* \param graph An immutable graph. * \param graph An immutable graph.
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
...@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef { ...@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef {
* space. * space.
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt, std::vector<IdArray> GetNodeFlowSlice(
size_t layer0_size, size_t layer1_start, const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_end, bool remap); size_t layer1_start, size_t layer1_end, bool remap);
} // namespace dgl } // namespace dgl
#endif // DGL_NODEFLOW_H_ #endif // DGL_NODEFLOW_H_
...@@ -7,14 +7,14 @@ ...@@ -7,14 +7,14 @@
#ifndef DGL_PACKED_FUNC_EXT_H_ #ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_ #define DGL_PACKED_FUNC_EXT_H_
#include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <memory>
#include <type_traits> #include <type_traits>
#include "./runtime/packed_func.h"
#include "./runtime/object.h"
#include "./runtime/container.h" #include "./runtime/container.h"
#include "./runtime/object.h"
#include "./runtime/packed_func.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -22,7 +22,7 @@ namespace runtime { ...@@ -22,7 +22,7 @@ namespace runtime {
* \brief Runtime type checker for node type. * \brief Runtime type checker for node type.
* \tparam T the type to be checked. * \tparam T the type to be checked.
*/ */
template<typename T> template <typename T>
struct ObjectTypeChecker { struct ObjectTypeChecker {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
// This is the only place in the project where RTTI is used // This is the only place in the project where RTTI is used
...@@ -31,13 +31,13 @@ struct ObjectTypeChecker { ...@@ -31,13 +31,13 @@ struct ObjectTypeChecker {
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
return sptr->derived_from<ContainerType>(); return sptr->derived_from<ContainerType>();
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key; os << ContainerType::_type_key;
} }
}; };
template<typename T> template <typename T>
struct ObjectTypeChecker<List<T> > { struct ObjectTypeChecker<List<T> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -48,14 +48,14 @@ struct ObjectTypeChecker<List<T> > { ...@@ -48,14 +48,14 @@ struct ObjectTypeChecker<List<T> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "list<"; os << "list<";
ObjectTypeChecker<T>::PrintName(os); ObjectTypeChecker<T>::PrintName(os);
os << ">"; os << ">";
} }
}; };
template<typename V> template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > { struct ObjectTypeChecker<Map<std::string, V> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -66,7 +66,7 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -66,7 +66,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string"; os << "map<string";
os << ','; os << ',';
ObjectTypeChecker<V>::PrintName(os); ObjectTypeChecker<V>::PrintName(os);
...@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
}; };
template <typename K, typename V>
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > { struct ObjectTypeChecker<Map<K, V> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -88,7 +86,7 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -88,7 +86,7 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<"; os << "map<";
ObjectTypeChecker<K>::PrintName(os); ObjectTypeChecker<K>::PrintName(os);
os << ','; os << ',';
...@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
}; };
template<typename T> template <typename T>
inline std::string NodeTypeName() { inline std::string NodeTypeName() {
std::ostringstream os; std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os); ObjectTypeChecker<T>::PrintName(os);
...@@ -106,7 +104,7 @@ inline std::string NodeTypeName() { ...@@ -106,7 +104,7 @@ inline std::string NodeTypeName() {
// extensions for DGLArgValue // extensions for DGLArgValue
template<typename TObjectRef> template <typename TObjectRef>
inline TObjectRef DGLArgValue::AsObjectRef() const { inline TObjectRef DGLArgValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
...@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const { ...@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle); DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >(); std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get())) CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TObjectRef>() << "Expected type " << NodeTypeName<TObjectRef>() << " but get "
<< " but get " << sptr->type_key(); << sptr->type_key();
return TObjectRef(sptr); return TObjectRef(sptr);
} }
...@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() { ...@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {
return *ptr<std::shared_ptr<Object> >(); return *ptr<std::shared_ptr<Object> >();
} }
template <typename TObjectRef, typename>
template<typename TObjectRef, typename>
inline bool DGLArgValue::IsObjectType() const { inline bool DGLArgValue::IsObjectType() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle); DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
*ptr<std::shared_ptr<Object> >();
return ObjectTypeChecker<TObjectRef>::Check(sptr.get()); return ObjectTypeChecker<TObjectRef>::Check(sptr.get());
} }
...@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) { ...@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {
return *this; return *this;
} }
template<typename TObjectRef> template <typename TObjectRef>
inline TObjectRef DGLRetValue::AsObjectRef() const { inline TObjectRef DGLRetValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
...@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const { ...@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const {
return TObjectRef(*ptr<std::shared_ptr<Object> >()); return TObjectRef(*ptr<std::shared_ptr<Object> >());
} }
inline void DGLArgsSetter::operator()(size_t i, const ObjectRef& other) const { // NOLINT(*) inline void DGLArgsSetter::operator()(
size_t i, const ObjectRef& other) const { // NOLINT(*)
if (other.defined()) { if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_)); values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));
type_codes_[i] = kObjectHandle; type_codes_[i] = kObjectHandle;
......
...@@ -8,8 +8,9 @@ ...@@ -8,8 +8,9 @@
#define DGL_RANDOM_H_ #define DGL_RANDOM_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <random> #include <random>
#include <thread> #include <thread>
#include <vector> #include <vector>
...@@ -46,33 +47,27 @@ class RandomEngine { ...@@ -46,33 +47,27 @@ class RandomEngine {
} }
/*! \brief Constructor with given seed */ /*! \brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) { explicit RandomEngine(uint32_t seed) { SetSeed(seed); }
SetSeed(seed);
}
/*! \brief Get the thread-local random number generator instance */ /*! \brief Get the thread-local random number generator instance */
static RandomEngine *ThreadLocal() { static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get(); return dmlc::ThreadLocalStore<RandomEngine>::Get();
} }
/*! /*!
* \brief Set the seed of this random number generator * \brief Set the seed of this random number generator
*/ */
void SetSeed(uint32_t seed) { void SetSeed(uint32_t seed) { rng_.seed(seed + GetThreadId()); }
rng_.seed(seed + GetThreadId());
}
/*! /*!
* \brief Generate an arbitrary random 32-bit integer. * \brief Generate an arbitrary random 32-bit integer.
*/ */
int32_t RandInt32() { int32_t RandInt32() { return static_cast<int32_t>(rng_()); }
return static_cast<int32_t>(rng_());
}
/*! /*!
* \brief Generate a uniform random integer in [0, upper) * \brief Generate a uniform random integer in [0, upper)
*/ */
template<typename T> template <typename T>
T RandInt(T upper) { T RandInt(T upper) {
return RandInt<T>(0, upper); return RandInt<T>(0, upper);
} }
...@@ -80,7 +75,7 @@ class RandomEngine { ...@@ -80,7 +75,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random integer in [lower, upper) * \brief Generate a uniform random integer in [lower, upper)
*/ */
template<typename T> template <typename T>
T RandInt(T lower, T upper) { T RandInt(T lower, T upper) {
CHECK_LT(lower, upper); CHECK_LT(lower, upper);
std::uniform_int_distribution<T> dist(lower, upper - 1); std::uniform_int_distribution<T> dist(lower, upper - 1);
...@@ -90,7 +85,7 @@ class RandomEngine { ...@@ -90,7 +85,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random float in [0, 1) * \brief Generate a uniform random float in [0, 1)
*/ */
template<typename T> template <typename T>
T Uniform() { T Uniform() {
return Uniform<T>(0., 1.); return Uniform<T>(0., 1.);
} }
...@@ -98,7 +93,7 @@ class RandomEngine { ...@@ -98,7 +93,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random float in [lower, upper) * \brief Generate a uniform random float in [lower, upper)
*/ */
template<typename T> template <typename T>
T Uniform(T lower, T upper) { T Uniform(T lower, T upper) {
// Although the result is in [lower, upper), we allow lower == upper as in // Although the result is in [lower, upper), we allow lower == upper as in
// www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/ // www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/
...@@ -108,23 +103,27 @@ class RandomEngine { ...@@ -108,23 +103,27 @@ class RandomEngine {
} }
/*! /*!
* \brief Pick a random integer between 0 to N-1 according to given probabilities * \brief Pick a random integer between 0 to N-1 according to given
* \tparam IdxType Return integer type * probabilities.
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \tparam IdxType Return integer type.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \return An integer randomly picked from 0 to N-1. * \return An integer randomly picked from 0 to N-1.
*/ */
template<typename IdxType> template <typename IdxType>
IdxType Choice(FloatArray prob); IdxType Choice(FloatArray prob);
/*! /*!
* \brief Pick random integers between 0 to N-1 according to given probabilities * \brief Pick random integers between 0 to N-1 according to given
* * probabilities
*
* If replace is false, the number of picked integers must not larger than N. * If replace is false, the number of picked integers must not larger than N.
* *
* \tparam IdxType Id type * \tparam IdxType Id type
* \tparam FloatType Probability value type * \tparam FloatType Probability value type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param out The output buffer to write selected indices. * \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
...@@ -132,14 +131,16 @@ class RandomEngine { ...@@ -132,14 +131,16 @@ class RandomEngine {
void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true); void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true);
/*! /*!
* \brief Pick random integers between 0 to N-1 according to given probabilities * \brief Pick random integers between 0 to N-1 according to given
* * probabilities
*
* If replace is false, the number of picked integers must not larger than N. * If replace is false, the number of picked integers must not larger than N.
* *
* \tparam IdxType Id type * \tparam IdxType Id type
* \tparam FloatType Probability value type * \tparam FloatType Probability value type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
* \return Picked indices * \return Picked indices
*/ */
...@@ -147,7 +148,8 @@ class RandomEngine { ...@@ -147,7 +148,8 @@ class RandomEngine {
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) { IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx); IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace); Choice<IdxType, FloatType>(
num, prob, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
...@@ -163,7 +165,8 @@ class RandomEngine { ...@@ -163,7 +165,8 @@ class RandomEngine {
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
template <typename IdxType> template <typename IdxType>
void UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace = true); void UniformChoice(
IdxType num, IdxType population, IdxType* out, bool replace = true);
/*! /*!
* \brief Pick random integers from population by uniform distribution. * \brief Pick random integers from population by uniform distribution.
...@@ -181,43 +184,48 @@ class RandomEngine { ...@@ -181,43 +184,48 @@ class RandomEngine {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now // TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace); UniformChoice<IdxType>(
num, population, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
/*! /*!
* \brief Pick random integers with different probability for different segments. * \brief Pick random integers with different probability for different
* segments.
* *
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some integers * For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some
* from 0 to 9, which is divided into two segments. 0-3 are in the first segment and the rest * integers from 0 to 9, which is divided into two segments. 0-3 are in the
* belongs to the second. The weight(bias) of each candidate in the first segment is upweighted * first segment and the rest belongs to the second. The weight(bias) of each
* to 1.5. * candidate in the first segment is upweighted to 1.5.
* *
* candidate | 0 1 2 3 | 4 5 6 7 8 9 | * candidate | 0 1 2 3 | 4 5 6 7 8 9 |
* split ^ ^ ^ * split ^ ^ ^
* bias | 1.5 | 1 | * bias | 1.5 | 1 |
* *
* *
* The complexity of this operator is O(k * log(T)) where k is the number of integers we want * The complexity of this operator is O(k * log(T)) where k is the number of
* to pick, and T is the number of segments. It is much faster compared with assigning * integers we want to pick, and T is the number of segments. It is much
* probability for each candidate, of which the complexity is O(k * log(N)) where N is the * faster compared with assigning probability for each candidate, of which the
* number of all candidates. * complexity is O(k * log(N)) where N is the number of all candidates.
* *
* If replace is false, num must not be larger than population. * If replace is false, num must not be larger than population.
* *
* \tparam IdxType Return integer type * \tparam IdxType Return integer type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param split Array of T+1 split positions of different segments(including start and end) * \param split Array of T+1 split positions of different segments(including
* \param bias Array of T weight of each segments * start and end)
* \param bias Array of T weight of each segments.
* \param out The output buffer to write selected indices. * \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
void BiasedChoice( void BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace = true); IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace = true);
/*! /*!
* \brief Pick random integers with different probability for different segments. * \brief Pick random integers with different probability for different
* segments.
* *
* If replace is false, num must not be larger than population. * If replace is false, num must not be larger than population.
* *
...@@ -229,10 +237,11 @@ class RandomEngine { ...@@ -229,10 +237,11 @@ class RandomEngine {
*/ */
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
IdArray BiasedChoice( IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) { IdxType num, const IdxType* split, FloatArray bias, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace); BiasedChoice<IdxType, FloatType>(
num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
......
...@@ -27,9 +27,8 @@ extern "C" { ...@@ -27,9 +27,8 @@ extern "C" {
* \param out The result function. * \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendGetFuncFromEnv(void* mod_node, DGL_DLL int DGLBackendGetFuncFromEnv(
const char* func_name, void* mod_node, const char* func_name, DGLFunctionHandle* out);
DGLFunctionHandle *out);
/*! /*!
* \brief Backend function to register system-wide library symbol. * \brief Backend function to register system-wide library symbol.
* *
...@@ -42,22 +41,21 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr); ...@@ -42,22 +41,21 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*! /*!
* \brief Backend function to allocate temporal workspace. * \brief Backend function to allocate temporal workspace.
* *
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. * \note The result allocate spaced is ensured to be aligned to
* kTempAllocaAlignment.
* *
* \param nbytes The size of the space requested. * \param nbytes The size of the space requested.
* \param device_type The device type which the space will be allocated. * \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated. * \param device_id The device id which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in * \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL. * certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in * \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL. * certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success * \return nullptr when error is thrown, a valid ptr if success
*/ */
DGL_DLL void* DGLBackendAllocWorkspace(int device_type, DGL_DLL void* DGLBackendAllocWorkspace(
int device_id, int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
uint64_t nbytes, int dtype_bits_hint);
int dtype_code_hint,
int dtype_bits_hint);
/*! /*!
* \brief Backend function to free temporal workspace. * \brief Backend function to free temporal workspace.
...@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type, ...@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
* *
* \sa DGLBackendAllocWorkspace * \sa DGLBackendAllocWorkspace
*/ */
DGL_DLL int DGLBackendFreeWorkspace(int device_type, DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);
int device_id,
void* ptr);
/*! /*!
* \brief Environment for DGL parallel task. * \brief Environment for DGL parallel task.
...@@ -100,13 +96,12 @@ typedef int (*FDGLParallelLambda)( ...@@ -100,13 +96,12 @@ typedef int (*FDGLParallelLambda)(
* \param flambda The parallel function to be launched. * \param flambda The parallel function to be launched.
* \param cdata The closure data. * \param cdata The closure data.
* \param num_task Number of tasks to launch, can be 0, means launch * \param num_task Number of tasks to launch, can be 0, means launch
* with all available threads. * with all available threads.
* *
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda, DGL_DLL int DGLBackendParallelLaunch(
void* cdata, FDGLParallelLambda flambda, void* cdata, int num_task);
int num_task);
/*! /*!
* \brief BSP barrrier between parallel threads * \brief BSP barrrier between parallel threads
...@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda, ...@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
*/ */
DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv); DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
/*! /*!
* \brief Simple static initialization fucntion. * \brief Simple static initialization fucntion.
* Run f once and set handle to be not null. * Run f once and set handle to be not null.
...@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv); ...@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
* \param nbytes Number of bytes in the closure data. * \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendRunOnce(void** handle, DGL_DLL int DGLBackendRunOnce(
int (*f)(void*), void** handle, int (*f)(void*), void* cdata, int nbytes);
void *cdata,
int nbytes);
#ifdef __cplusplus #ifdef __cplusplus
} // DGL_EXTERN_C } // DGL_EXTERN_C
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment