"csrc/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "fff381c54204ae9b4e61fe4df7ee5211366a1304"
Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
...@@ -15,9 +15,9 @@ namespace featgraph { ...@@ -15,9 +15,9 @@ namespace featgraph {
void LoadFeatGraphModule(const std::string& path); void LoadFeatGraphModule(const std::string& path);
/* \brief Call Featgraph's SDDMM kernel. */ /* \brief Call Featgraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col, void SDDMMTreeReduction(
DLManagedTensor* lhs, DLManagedTensor* rhs, DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* out); DLManagedTensor* rhs, DLManagedTensor* out);
} // namespace featgraph } // namespace featgraph
} // namespace dgl } // namespace dgl
......
...@@ -3,18 +3,18 @@ ...@@ -3,18 +3,18 @@
* \file featgraph/src/featgraph.cc * \file featgraph/src/featgraph.cc
* \brief FeatGraph kernels. * \brief FeatGraph kernels.
*/ */
#include <dmlc/logging.h>
#include <featgraph.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <featgraph.h>
namespace dgl { namespace dgl {
namespace featgraph { namespace featgraph {
/* \brief Singleton that loads the featgraph module. */ /* \brief Singleton that loads the featgraph module. */
class FeatGraphModule { class FeatGraphModule {
public: public:
static FeatGraphModule* Global() { static FeatGraphModule* Global() {
static FeatGraphModule inst; static FeatGraphModule inst;
return &inst; return &inst;
...@@ -32,7 +32,8 @@ public: ...@@ -32,7 +32,8 @@ public:
} }
return ret; return ret;
} }
private:
private:
tvm::runtime::Module mod; tvm::runtime::Module mod;
FeatGraphModule() {} FeatGraphModule() {}
}; };
...@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) { ...@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) {
/* \brief Convert DLDataType to string. */ /* \brief Convert DLDataType to string. */
inline std::string DTypeAsStr(const DLDataType& t) { inline std::string DTypeAsStr(const DLDataType& t) {
switch(t.code) { switch (t.code) {
case 0U: return "int" + std::to_string(t.bits); case 0U:
case 1U: return "uint" + std::to_string(t.bits); return "int" + std::to_string(t.bits);
case 2U: return "float" + std::to_string(t.bits); case 1U:
case 3U: return "bfloat" + std::to_string(t.bits); return "uint" + std::to_string(t.bits);
default: LOG(FATAL) << "Type code " << t.code << " not recognized"; case 2U:
return "float" + std::to_string(t.bits);
case 3U:
return "bfloat" + std::to_string(t.bits);
default:
LOG(FATAL) << "Type code " << t.code << " not recognized";
} }
} }
/* \brief Get operator filename. */ /* \brief Get operator filename. */
inline std::string GetOperatorName( inline std::string GetOperatorName(
const std::string& base_name, const std::string& base_name, const DLDataType& dtype,
const DLDataType& dtype,
const DLDataType& idtype) { const DLDataType& idtype) {
return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype); return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
} }
/* \brief Call FeatGraph's SDDMM kernel. */ /* \brief Call FeatGraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col, void SDDMMTreeReduction(
DLManagedTensor* lhs, DLManagedTensor* rhs, DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* out) { DLManagedTensor* rhs, DLManagedTensor* out) {
tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get(); tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
std::string f_name = GetOperatorName("SDDMMTreeReduction", std::string f_name = GetOperatorName(
(row->dl_tensor).dtype, "SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
(lhs->dl_tensor).dtype);
tvm::runtime::PackedFunc f = mod->GetFunction(f_name); tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
if (f != nullptr) if (f != nullptr) f(row, col, lhs, rhs, out);
f(row, col, lhs, rhs, out);
} }
} // namespace featgraph } // namespace featgraph
......
/* /*
* NOTE(zihao): this file was modified from TVM project: * NOTE(zihao): this file was modified from TVM project:
* - https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc * -
* * https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
*
* Licensed to the Apache Software Foundation (ASF) under one * Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file * or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information * distributed with this work for additional information
......
...@@ -8,10 +8,10 @@ ...@@ -8,10 +8,10 @@
*/ */
#ifndef DGL_ARRAY_H_ #ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_ #define DGL_ARRAY_H_
#include "./aten/types.h"
#include "./aten/array_ops.h" #include "./aten/array_ops.h"
#include "./aten/coo.h"
#include "./aten/csr.h"
#include "./aten/macro.h" #include "./aten/macro.h"
#include "./aten/spmat.h" #include "./aten/spmat.h"
#include "./aten/csr.h" #include "./aten/types.h"
#include "./aten/coo.h"
#endif // DGL_ARRAY_H_ #endif // DGL_ARRAY_H_
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
#define CUB_INLINE inline #define CUB_INLINE inline
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
#include <iterator>
#include <algorithm> #include <algorithm>
#include <iterator>
#include <utility> #include <utility>
namespace dgl { namespace dgl {
...@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) { ...@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) {
r1.swap(r2); r1.swap(r2);
} }
// PairRef and PairIterator that serves as an iterator over a pair of arrays in a // PairRef and PairIterator that serves as an iterator over a pair of arrays in
// zipped fashion like zip(a, b). // a zipped fashion like zip(a, b).
template <typename DType> template <typename DType>
struct PairRef { struct PairRef {
PairRef() = delete; PairRef() = delete;
...@@ -69,9 +69,7 @@ struct PairRef { ...@@ -69,9 +69,7 @@ struct PairRef {
*b = val.second; *b = val.second;
return *this; return *this;
} }
CUB_INLINE operator Pair<DType>() const { CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }
return Pair<DType>(*a, *b);
}
CUB_INLINE operator std::pair<DType, DType>() const { CUB_INLINE operator std::pair<DType, DType>() const {
return std::make_pair(*a, *b); return std::make_pair(*a, *b);
} }
...@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) { ...@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {
} }
template <typename DType> template <typename DType>
struct PairIterator : public std::iterator<std::random_access_iterator_tag, struct PairIterator : public std::iterator<
Pair<DType>, std::random_access_iterator_tag, Pair<DType>,
std::ptrdiff_t, std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {
Pair<DType*>,
PairRef<DType>> {
PairIterator() = default; PairIterator() = default;
PairIterator(const PairIterator& other) = default; PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default; PairIterator(PairIterator&& other) = default;
...@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
PairIterator& operator=(const PairIterator& other) = default; PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default; PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default; ~PairIterator() = default;
CUB_INLINE bool operator==(const PairIterator& other) const { return a == other.a; } CUB_INLINE bool operator==(const PairIterator& other) const {
CUB_INLINE bool operator!=(const PairIterator& other) const { return a != other.a; } return a == other.a;
CUB_INLINE bool operator<(const PairIterator& other) const { return a < other.a; } }
CUB_INLINE bool operator>(const PairIterator& other) const { return a > other.a; } CUB_INLINE bool operator!=(const PairIterator& other) const {
CUB_INLINE bool operator<=(const PairIterator& other) const { return a <= other.a; } return a != other.a;
CUB_INLINE bool operator>=(const PairIterator& other) const { return a >= other.a; } }
CUB_INLINE bool operator<(const PairIterator& other) const {
return a < other.a;
}
CUB_INLINE bool operator>(const PairIterator& other) const {
return a > other.a;
}
CUB_INLINE bool operator<=(const PairIterator& other) const {
return a <= other.a;
}
CUB_INLINE bool operator>=(const PairIterator& other) const {
return a >= other.a;
}
CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) { CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {
a += movement; a += movement;
b += movement; b += movement;
...@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag, ...@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const { CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {
return a - other.a; return a - other.a;
} }
CUB_INLINE PairRef<DType> operator*() const { CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }
return PairRef<DType>(a, b); CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }
}
CUB_INLINE PairRef<DType> operator*() {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator[](size_t offset) const { CUB_INLINE PairRef<DType> operator[](size_t offset) const {
return PairRef<DType>(a + offset, b + offset); return PairRef<DType>(a + offset, b + offset);
} }
......
...@@ -8,8 +8,9 @@ ...@@ -8,8 +8,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./types.h"
#include "../runtime/object.h" #include "../runtime/object.h"
#include "./types.h"
namespace dgl { namespace dgl {
...@@ -56,31 +57,28 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) { ...@@ -56,31 +57,28 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) { inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret; std::vector<SparseFormat> ret;
if (code & COO_CODE) if (code & COO_CODE) ret.push_back(SparseFormat::kCOO);
ret.push_back(SparseFormat::kCOO); if (code & CSR_CODE) ret.push_back(SparseFormat::kCSR);
if (code & CSR_CODE) if (code & CSC_CODE) ret.push_back(SparseFormat::kCSC);
ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC);
return ret; return ret;
} }
inline dgl_format_code_t inline dgl_format_code_t SparseFormatsToCode(
SparseFormatsToCode(const std::vector<SparseFormat> &formats) { const std::vector<SparseFormat>& formats) {
dgl_format_code_t ret = 0; dgl_format_code_t ret = 0;
for (auto format : formats) { for (auto format : formats) {
switch (format) { switch (format) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
ret |= COO_CODE; ret |= COO_CODE;
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
ret |= CSR_CODE; ret |= CSR_CODE;
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
ret |= CSC_CODE; ret |= CSC_CODE;
break; break;
default: default:
LOG(FATAL) << "Only support COO/CSR/CSC formats."; LOG(FATAL) << "Only support COO/CSR/CSC formats.";
} }
} }
return ret; return ret;
...@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) { ...@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline std::string CodeToStr(dgl_format_code_t code) { inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = ""; std::string ret = "";
if (code & COO_CODE) if (code & COO_CODE) ret += "coo ";
ret += "coo "; if (code & CSR_CODE) ret += "csr ";
if (code & CSR_CODE) if (code & CSC_CODE) ret += "csc ";
ret += "csr ";
if (code & CSC_CODE)
ret += "csc ";
return ret; return ret;
} }
inline SparseFormat DecodeFormat(dgl_format_code_t code) { inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & COO_CODE) if (code & COO_CODE) return SparseFormat::kCOO;
return SparseFormat::kCOO; if (code & CSC_CODE) return SparseFormat::kCSC;
if (code & CSC_CODE)
return SparseFormat::kCSC;
return SparseFormat::kCSR; return SparseFormat::kCSR;
} }
...@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object { ...@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object {
// Shape of this matrix. // Shape of this matrix.
int64_t num_rows = 0, num_cols = 0; int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}. // Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,
// col, data}.
std::vector<IdArray> indices; std::vector<IdArray> indices;
// Boolean flags. // Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently, // TODO(minjie): We might revisit this later to provide a more general
// we only consider aten::COOMatrix and aten::CSRMatrix. // solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags; std::vector<bool> flags;
SparseMatrix() {} SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols, SparseMatrix(
const std::vector<IdArray>& idx, int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<bool>& flg) const std::vector<IdArray>& idx, const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {} : format(fmt),
num_rows(nrows),
num_cols(ncols),
indices(idx),
flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix"; static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
......
/*! /*!
* Copyright (c) 2020 by Contributors * Copyright (c) 2020 by Contributors
* \file dgl/aten/types.h * \file dgl/aten/types.h
* \brief Array and ID types * \brief Array and ID types
*/ */
#ifndef DGL_ATEN_TYPES_H_ #ifndef DGL_ATEN_TYPES_H_
#define DGL_ATEN_TYPES_H_ #define DGL_ATEN_TYPES_H_
#include <cstdint> #include <cstdint>
#include "../runtime/ndarray.h" #include "../runtime/ndarray.h"
namespace dgl { namespace dgl {
...@@ -15,7 +16,7 @@ typedef uint64_t dgl_id_t; ...@@ -15,7 +16,7 @@ typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t; typedef uint64_t dgl_type_t;
/*! \brief Type for dgl fomrat code, whose binary representation indices /*! \brief Type for dgl fomrat code, whose binary representation indices
* which sparse format is in use and which is not. * which sparse format is in use and which is not.
* *
* Suppose the binary representation is xyz, then * Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false). * - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use. * - y indicates whether csr is in use.
......
...@@ -7,17 +7,17 @@ ...@@ -7,17 +7,17 @@
#ifndef DGL_BASE_HETEROGRAPH_H_ #ifndef DGL_BASE_HETEROGRAPH_H_
#define DGL_BASE_HETEROGRAPH_H_ #define DGL_BASE_HETEROGRAPH_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h" #include "./runtime/object.h"
#include "array.h"
#include "aten/spmat.h" #include "aten/spmat.h"
#include "aten/types.h" #include "aten/types.h"
#include "graph_interface.h" #include "graph_interface.h"
#include "array.h"
namespace dgl { namespace dgl {
...@@ -52,31 +52,26 @@ enum class EdgeDir { ...@@ -52,31 +52,26 @@ enum class EdgeDir {
*/ */
class BaseHeteroGraph : public runtime::Object { class BaseHeteroGraph : public runtime::Object {
public: public:
explicit BaseHeteroGraph(GraphPtr meta_graph): meta_graph_(meta_graph) {} explicit BaseHeteroGraph(GraphPtr meta_graph) : meta_graph_(meta_graph) {}
virtual ~BaseHeteroGraph() = default; virtual ~BaseHeteroGraph() = default;
////////////////////////// query/operations on meta graph //////////////////////// ////////////////////// query/operations on meta graph ///////////////////////
/*! \return the number of vertex types */ /*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const { virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); }
return meta_graph_->NumVertices();
}
/*! \return the number of edge types */ /*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const { virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); }
return meta_graph_->NumEdges();
}
/*! \return given the edge type, find the source type */ /*! \return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(dgl_type_t etype) const { virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(
dgl_type_t etype) const {
return meta_graph_->FindEdge(etype); return meta_graph_->FindEdge(etype);
} }
/*! \return the meta graph */ /*! \return the meta graph */
virtual GraphPtr meta_graph() const { virtual GraphPtr meta_graph() const { return meta_graph_; }
return meta_graph_;
}
/*! /*!
* \brief Return the bipartite graph of the given edge type. * \brief Return the bipartite graph of the given edge type.
...@@ -85,7 +80,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -85,7 +80,7 @@ class BaseHeteroGraph : public runtime::Object {
*/ */
virtual HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const = 0; virtual HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const = 0;
////////////////////////// query/operations on realized graph //////////////////////// ///////////////////// query/operations on realized graph /////////////////////
/*! \brief Add vertices to the given vertex type */ /*! \brief Add vertices to the given vertex type */
virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0; virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0;
...@@ -128,7 +123,8 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -128,7 +123,8 @@ class BaseHeteroGraph : public runtime::Object {
virtual void RecordStream(DGLStreamHandle stream) = 0; virtual void RecordStream(DGLStreamHandle stream) = 0;
/*! /*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64). * \brief Get the number of integer bits used to store node/edge ids (32 or
* 64).
*/ */
// TODO(BarclayII) replace NumBits() calls to DataType() calls // TODO(BarclayII) replace NumBits() calls to DataType() calls
virtual uint8_t NumBits() const = 0; virtual uint8_t NumBits() const = 0;
...@@ -156,14 +152,18 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -156,14 +152,18 @@ class BaseHeteroGraph : public runtime::Object {
/*! \return true if the given vertex is in the graph.*/ /*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0; virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0; virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; virtual bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/ /*! \return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0; virtual BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
...@@ -187,14 +187,13 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -187,14 +187,13 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Get all edge ids between the two given endpoints * \brief Get all edge ids between the two given endpoints
* \note The given src and dst vertices should belong to the source vertex type * \note The given src and dst vertices should belong to the source vertex
* and the dest vertex type of the given edge type, respectively. * type and the dest vertex type of the given edge type, respectively. \param
* \param etype The edge type * etype The edge type \param src The source vertex. \param dst The
* \param src The source vertex. * destination vertex. \return the edge id array.
* \param dst The destination vertex.
* \return the edge id array.
*/ */
virtual IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0; virtual IdArray EdgeId(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! /*!
* \brief Get all edge ids between the given endpoint pairs. * \brief Get all edge ids between the given endpoint pairs.
...@@ -204,34 +203,40 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -204,34 +203,40 @@ class BaseHeteroGraph : public runtime::Object {
* \param dst The dst vertex ids. * \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs. * \return EdgeArray containing all edges between all pairs.
*/ */
virtual EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const = 0; virtual EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*! /*!
* \brief Get edge ids between the given endpoint pairs. * \brief Get edge ids between the given endpoint pairs.
* *
* Only find one matched edge Ids even if there are multiple matches due to parallel * Only find one matched edge Ids even if there are multiple matches due to
* edges. The i^th Id in the returned array is for edge (src[i], dst[i]). * parallel edges. The i^th Id in the returned array is for edge (src[i],
* dst[i]).
* *
* \param etype The edge type * \param etype The edge type
* \param src The src vertex ids. * \param src The src vertex ids.
* \param dst The dst vertex ids. * \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs. * \return EdgeArray containing all edges between all pairs.
*/ */
virtual IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const = 0; virtual IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param etype The edge type * \param etype The edge type
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const = 0; virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const = 0;
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param etype The edge type * \param etype The edge type
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0; virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0;
...@@ -277,14 +282,14 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -277,14 +282,14 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and * \note If order is "srcdst", the returned edges list is sorted by their src
* dst ids. If order is "eid", they are in their edge id order. * and dst ids. If order is "eid", they are in their edge id order. Otherwise,
* Otherwise, in the arbitrary order. * in the arbitrary order. \param etype The edge type \param order The order
* \param etype The edge type * of the returned edge list. \return the id arrays of the two endpoints of
* \param order The order of the returned edge list. * the edges.
* \return the id arrays of the two endpoints of the edges.
*/ */
virtual EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const = 0; virtual EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const = 0;
/*! /*!
* \brief Get the in degree of the given vertex. * \brief Get the in degree of the given vertex.
...@@ -373,15 +378,15 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -373,15 +378,15 @@ class BaseHeteroGraph : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing * If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids * indptr, indices and edge ids
* *
* If the fmt is 'coo', the function should return one array of shape (2, nnz), * If the fmt is 'coo', the function should return one array of shape (2,
* representing a horitonzal stack of row and col indices. * nnz), representing a horitonzal stack of row and col indices.
* *
* \param transpose A flag to transpose the returned adjacency matrix. * \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
virtual std::vector<IdArray> GetAdj( virtual std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const = 0; dgl_type_t etype, bool transpose, const std::string& fmt) const = 0;
/*! /*!
* \brief Determine which format to use with a preference. * \brief Determine which format to use with a preference.
...@@ -451,39 +456,42 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -451,39 +456,42 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Extract the induced subgraph by the given vertices. * \brief Extract the induced subgraph by the given vertices.
* *
* The length of the given vector should be equal to the number of vertex types. * The length of the given vector should be equal to the number of vertex
* Empty arrays can be provided if no vertex is needed for the type. The result * types. Empty arrays can be provided if no vertex is needed for the type.
* subgraph has the same meta graph with the parent, but some types can have no * The result subgraph has the same meta graph with the parent, but some types
* node/edge. * can have no node/edge.
* *
* \param vids the induced vertices per type. * \param vids the induced vertices per type.
* \return the subgraph. * \return the subgraph.
*/ */
virtual HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const = 0; virtual HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const = 0;
/*! /*!
* \brief Extract the induced subgraph by the given edges. * \brief Extract the induced subgraph by the given edges.
* *
* The length of the given vector should be equal to the number of edge types. * The length of the given vector should be equal to the number of edge types.
* Empty arrays can be provided if no edge is needed for the type. The result * Empty arrays can be provided if no edge is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no * subgraph has the same meta graph with the parent, but some types can have
* node/edge. * no node/edge.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices * \param preserve_nodes If true, the vertices will not be relabeled, so some
* may have no incident edges. * vertices may have no incident edges. \return the subgraph.
* \return the subgraph.
*/ */
virtual HeteroSubgraph EdgeSubgraph( virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0; const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*! /*!
* \brief Convert the list of requested unitgraph graphs into a single unitgraph graph. * \brief Convert the list of requested unitgraph graphs into a single
* unitgraph graph.
* *
* \param etypes The list of edge type IDs. * \param etypes The list of edge type IDs.
* \return The flattened graph, with induced source/edge/destination types/IDs. * \return The flattened graph, with induced source/edge/destination
* types/IDs.
*/ */
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const { virtual FlattenedHeteroGraphPtr Flatten(
const std::vector<dgl_type_t>& etypes) const {
LOG(FATAL) << "Flatten operation unsupported"; LOG(FATAL) << "Flatten operation unsupported";
return nullptr; return nullptr;
} }
...@@ -502,7 +510,7 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -502,7 +510,7 @@ class BaseHeteroGraph : public runtime::Object {
GraphPtr meta_graph_; GraphPtr meta_graph_;
// empty constructor // empty constructor
BaseHeteroGraph(){} BaseHeteroGraph() {}
}; };
// Define HeteroGraphRef // Define HeteroGraphRef
...@@ -527,9 +535,9 @@ struct HeteroSubgraph : public runtime::Object { ...@@ -527,9 +535,9 @@ struct HeteroSubgraph : public runtime::Object {
HeteroGraphPtr graph; HeteroGraphPtr graph;
/*! /*!
* \brief The induced vertex ids of each entity type. * \brief The induced vertex ids of each entity type.
* The vector length is equal to the number of vertex types in the parent graph. * The vector length is equal to the number of vertex types in the parent
* Each array i has the same length as the number of vertices in type i. * graph. Each array i has the same length as the number of vertices in type
* Empty array is allowed if the mapping is identity. * i. Empty array is allowed if the mapping is identity.
*/ */
std::vector<IdArray> induced_vertices; std::vector<IdArray> induced_vertices;
/*! /*!
...@@ -553,7 +561,8 @@ struct FlattenedHeteroGraph : public runtime::Object { ...@@ -553,7 +561,8 @@ struct FlattenedHeteroGraph : public runtime::Object {
HeteroGraphRef graph; HeteroGraphRef graph;
/*! /*!
* \brief Mapping from source node ID to node type in parent graph * \brief Mapping from source node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously. * \note The induced type array guarantees that the same type always appear
* contiguously.
*/ */
IdArray induced_srctype; IdArray induced_srctype;
/*! /*!
...@@ -564,7 +573,8 @@ struct FlattenedHeteroGraph : public runtime::Object { ...@@ -564,7 +573,8 @@ struct FlattenedHeteroGraph : public runtime::Object {
IdArray induced_srcid; IdArray induced_srcid;
/*! /*!
* \brief Mapping from edge ID to edge type in parent graph * \brief Mapping from edge ID to edge type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously. * \note The induced type array guarantees that the same type always appear
* contiguously.
*/ */
IdArray induced_etype; IdArray induced_etype;
/*! /*!
...@@ -575,17 +585,20 @@ struct FlattenedHeteroGraph : public runtime::Object { ...@@ -575,17 +585,20 @@ struct FlattenedHeteroGraph : public runtime::Object {
IdArray induced_eid; IdArray induced_eid;
/*! /*!
* \brief Mapping from destination node ID to node type in parent graph * \brief Mapping from destination node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously. * \note The induced type array guarantees that the same type always appear
* contiguously.
*/ */
IdArray induced_dsttype; IdArray induced_dsttype;
/*! /*!
* \brief The set of node types in parent graph appearing in destination nodes. * \brief The set of node types in parent graph appearing in destination
* nodes.
*/ */
IdArray induced_dsttype_set; IdArray induced_dsttype_set;
/*! \brief Mapping from destination node ID to local node ID in parent graph */ /*! \brief Mapping from destination node ID to local node ID in parent graph
*/
IdArray induced_dstid; IdArray induced_dstid;
void VisitAttrs(runtime::AttrVisitor *v) final { void VisitAttrs(runtime::AttrVisitor* v) final {
v->Visit("graph", &graph); v->Visit("graph", &graph);
v->Visit("induced_srctype", &induced_srctype); v->Visit("induced_srctype", &induced_srctype);
v->Visit("induced_srctype_set", &induced_srctype_set); v->Visit("induced_srctype_set", &induced_srctype_set);
...@@ -610,9 +623,8 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph); ...@@ -610,9 +623,8 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
* additionally specifying number of nodes per type. * additionally specifying number of nodes per type.
*/ */
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<HeteroGraphPtr> &rel_graphs, const std::vector<int64_t>& num_nodes_per_type = {});
const std::vector<int64_t> &num_nodes_per_type = {});
/*! /*!
* \brief Create a heterograph from COO input. * \brief Create a heterograph from COO input.
...@@ -629,8 +641,8 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -629,8 +641,8 @@ HeteroGraphPtr CreateHeteroGraph(
* \return A heterograph pointer. * \return A heterograph pointer.
*/ */
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray row, IdArray col, bool row_sorted = false, bool col_sorted = false, IdArray col, bool row_sorted = false, bool col_sorted = false,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/*! /*!
...@@ -656,9 +668,8 @@ HeteroGraphPtr CreateFromCOO( ...@@ -656,9 +668,8 @@ HeteroGraphPtr CreateFromCOO(
* \return A heterograph pointer. * \return A heterograph pointer.
*/ */
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSR input. * \brief Create a heterograph from CSR input.
...@@ -683,9 +694,8 @@ HeteroGraphPtr CreateFromCSR( ...@@ -683,9 +694,8 @@ HeteroGraphPtr CreateFromCSR(
* \return A heterograph pointer. * \return A heterograph pointer.
*/ */
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSC input. * \brief Create a heterograph from CSC input.
...@@ -702,23 +712,25 @@ HeteroGraphPtr CreateFromCSC( ...@@ -702,23 +712,25 @@ HeteroGraphPtr CreateFromCSC(
* \brief Extract the subgraph of the in edges of the given nodes. * \brief Extract the subgraph of the in edges of the given nodes.
* \param graph Graph * \param graph Graph
* \param nodes Node IDs of each type * \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones * \param relabel_nodes Whether to remove isolated nodes and relabel the rest
* \return Subgraph containing only the in edges. The returned graph has the same * ones \return Subgraph containing only the in edges. The returned graph has
* schema as the original one. * the same schema as the original one.
*/ */
HeteroSubgraph InEdgeGraph( HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false); const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false);
/*! /*!
* \brief Extract the subgraph of the out edges of the given nodes. * \brief Extract the subgraph of the out edges of the given nodes.
* \param graph Graph * \param graph Graph
* \param nodes Node IDs of each type * \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones * \param relabel_nodes Whether to remove isolated nodes and relabel the rest
* \return Subgraph containing only the out edges. The returned graph has the same * ones \return Subgraph containing only the out edges. The returned graph has
* schema as the original one. * the same schema as the original one.
*/ */
HeteroSubgraph OutEdgeGraph( HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false); const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false);
/*! /*!
* \brief Joint union multiple graphs into one graph. * \brief Joint union multiple graphs into one graph.
...@@ -735,7 +747,8 @@ HeteroGraphPtr JointUnionHeteroGraph( ...@@ -735,7 +747,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*! /*!
* \brief Union multiple graphs into one with each input graph as one disjoint component. * \brief Union multiple graphs into one with each input graph as one disjoint
* component.
* *
* All input graphs should have the same metagraph. * All input graphs should have the same metagraph.
* *
...@@ -754,7 +767,8 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -754,7 +767,8 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*! /*!
* \brief Slice a contiguous subgraph, e.g. retrieve a component graph from a batched graph. * \brief Slice a contiguous subgraph, e.g. retrieve a component graph from a
* batched graph.
* *
* TODO(mufei): remove the meta_graph argument * TODO(mufei): remove the meta_graph argument
* *
...@@ -767,25 +781,22 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -767,25 +781,22 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
* \return Sliced graph * \return Sliced graph
*/ */
HeteroGraphPtr SliceHeteroGraph( HeteroGraphPtr SliceHeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, HeteroGraphPtr batched_graph,
HeteroGraphPtr batched_graph, IdArray num_nodes_per_type, IdArray start_nid_per_type,
IdArray num_nodes_per_type, IdArray num_edges_per_type, IdArray start_eid_per_type);
IdArray start_nid_per_type,
IdArray num_edges_per_type,
IdArray start_eid_per_type);
/*! /*!
* \brief Split a graph into multiple disjoin components. * \brief Split a graph into multiple disjoin components.
* *
* Edges across different components are ignored. All the result graphs have the same * Edges across different components are ignored. All the result graphs have the
* metagraph as the input one. * same metagraph as the input one.
* *
* The `vertex_sizes` and `edge_sizes` arrays the concatenation of arrays of each * The `vertex_sizes` and `edge_sizes` arrays the concatenation of arrays of
* node/edge type. Suppose there are N vertex types, then the array length should * each node/edge type. Suppose there are N vertex types, then the array length
* be B*N, where B is the number of components to split. * should be B*N, where B is the number of components to split.
* *
* TODO(minjie): remove the meta_graph argument; use vector<IdArray> for vertex_sizes * TODO(minjie): remove the meta_graph argument; use vector<IdArray> for
* and edge_sizes. * vertex_sizes and edge_sizes.
* *
* \tparam IdType Graph's index data type, can be int32_t or int64_t * \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param meta_graph Metagraph. * \param meta_graph Metagraph.
...@@ -796,16 +807,11 @@ HeteroGraphPtr SliceHeteroGraph( ...@@ -796,16 +807,11 @@ HeteroGraphPtr SliceHeteroGraph(
*/ */
template <class IdType> template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
HeteroGraphPtr batched_graph,
IdArray vertex_sizes,
IdArray edge_sizes); IdArray edge_sizes);
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph, GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
HeteroGraphPtr batched_graph,
IdArray vertex_sizes,
IdArray edge_sizes); IdArray edge_sizes);
/*! /*!
...@@ -862,7 +868,8 @@ DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates); ...@@ -862,7 +868,8 @@ DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates);
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);
/*! /*!
* \brief Get the pickling state of the relation graph structure in backend tensors. * \brief Get the pickling state of the relation graph structure in backend
* tensors.
* *
* \return a HeteroPickleStates object * \return a HeteroPickleStates object
*/ */
...@@ -886,8 +893,8 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); ...@@ -886,8 +893,8 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states); HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
/*! /*!
* \brief Get the pickling states of the relation graph structure in backend tensors for * \brief Get the pickling states of the relation graph structure in backend
* ForkingPickler. * tensors for ForkingPickler.
* *
* This is different from HeteroPickle where * This is different from HeteroPickle where
* (1) Backward compatibility is not required, * (1) Backward compatibility is not required,
...@@ -895,14 +902,11 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states); ...@@ -895,14 +902,11 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
*/ */
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph); HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph);
#define FORMAT_HAS_CSC(format) \ #define FORMAT_HAS_CSC(format) ((format)&CSC_CODE)
((format) & CSC_CODE)
#define FORMAT_HAS_CSR(format) \ #define FORMAT_HAS_CSR(format) ((format)&CSR_CODE)
((format) & CSR_CODE)
#define FORMAT_HAS_COO(format) \ #define FORMAT_HAS_COO(format) ((format)&COO_CODE)
((format) & COO_CODE)
} // namespace dgl } // namespace dgl
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "./runtime/ndarray.h" #include "./runtime/ndarray.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -37,7 +38,7 @@ struct BcastOff { ...@@ -37,7 +38,7 @@ struct BcastOff {
/*! \brief Whether broadcast is required or not. */ /*! \brief Whether broadcast is required or not. */
bool use_bcast; bool use_bcast;
/*! /*!
* \brief Auxiliary information for kernel computation * \brief Auxiliary information for kernel computation
* \note lhs_len refers to the left hand side operand length. * \note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5) * e.g. 15 for shape (1, 3, 5)
* rhs_len refers to the right hand side operand length. * rhs_len refers to the right hand side operand length.
...@@ -61,7 +62,6 @@ struct BcastOff { ...@@ -61,7 +62,6 @@ struct BcastOff {
*/ */
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs); BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl } // namespace dgl
#endif // DGL_BCAST_H_ #endif // DGL_BCAST_H_
...@@ -6,13 +6,12 @@ ...@@ -6,13 +6,12 @@
#ifndef DGL_GRAPH_H_ #ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_ #define DGL_GRAPH_H_
#include <string>
#include <vector>
#include <string>
#include <cstdint> #include <cstdint>
#include <utility>
#include <tuple>
#include <memory> #include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "graph_interface.h" #include "graph_interface.h"
...@@ -23,7 +22,7 @@ class GraphOp; ...@@ -23,7 +22,7 @@ class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr; typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! \brief Mutable graph based on adjacency list. */ /*! \brief Mutable graph based on adjacency list. */
class Graph: public GraphInterface { class Graph : public GraphInterface {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
Graph() {} Graph() {}
...@@ -89,13 +88,9 @@ class Graph: public GraphInterface { ...@@ -89,13 +88,9 @@ class Graph: public GraphInterface {
num_edges_ = 0; num_edges_ = 0;
} }
DGLContext Context() const override { DGLContext Context() const override { return DGLContext{kDGLCPU, 0}; }
return DGLContext{kDGLCPU, 0};
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return 64; }
return 64;
}
/*! /*!
* \note not const since we have caches * \note not const since we have caches
...@@ -106,21 +101,17 @@ class Graph: public GraphInterface { ...@@ -106,21 +101,17 @@ class Graph: public GraphInterface {
/*! /*!
* \return whether the graph is read-only * \return whether the graph is read-only
*/ */
bool IsReadonly() const override { bool IsReadonly() const override { return false; }
return false;
}
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override { uint64_t NumVertices() const override { return adjlist_.size(); }
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/ /*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override { uint64_t NumEdges() const override { return num_edges_; }
return num_edges_;
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
...@@ -132,7 +123,8 @@ class Graph: public GraphInterface { ...@@ -132,7 +123,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;
...@@ -140,7 +132,8 @@ class Graph: public GraphInterface { ...@@ -140,7 +132,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override; IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
...@@ -169,7 +162,8 @@ class Graph: public GraphInterface { ...@@ -169,7 +162,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]); return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);
...@@ -178,7 +172,8 @@ class Graph: public GraphInterface { ...@@ -178,7 +172,8 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
EdgeArray FindEdges(IdArray eids) const override; EdgeArray FindEdges(IdArray eids) const override;
...@@ -216,10 +211,11 @@ class Graph: public GraphInterface { ...@@ -216,10 +211,11 @@ class Graph: public GraphInterface {
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and * \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order. * dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids * \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray Edges(const std::string &order = "") const override; EdgeArray Edges(const std::string& order = "") const override;
/*! /*!
* \brief Get the in degree of the given vertex. * \brief Get the in degree of the given vertex.
...@@ -258,13 +254,14 @@ class Graph: public GraphInterface { ...@@ -258,13 +254,14 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then * The induced subgraph is a subgraph formed by specifying a set of vertices
* selecting all of the edges from the original graph that connect two vertices in V'. * V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the vertices preserve the order of the given id array, while the local index * index. The local index of the vertices preserve the order of the given id
* of the edges preserve the index order in the original graph. Vertices not in the * array, while the local index of the edges preserve the index order in the
* original graph are ignored. * original graph. Vertices not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
...@@ -276,20 +273,22 @@ class Graph: public GraphInterface { ...@@ -276,20 +273,22 @@ class Graph: public GraphInterface {
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then * The induced edges subgraph is a subgraph formed by specifying a set of
* selecting all of the nodes from the original graph that are endpoints in E'. * edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the edges preserve the order of the given id array, while the local index * index. The local index of the edges preserve the order of the given id
* of the vertices preserve the index order in the original graph. Edges not in the * array, while the local index of the vertices preserve the index order in
* original graph are ignored. * the original graph. Edges not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
...@@ -344,12 +343,11 @@ class Graph: public GraphInterface { ...@@ -344,12 +343,11 @@ class Graph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray. * \return a vector of three IdArray.
*/ */
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override; std::vector<IdArray> GetAdj(
bool transpose, const std::string& fmt) const override;
/*! \brief Create an empty graph */ /*! \brief Create an empty graph */
static MutableGraphPtr Create() { static MutableGraphPtr Create() { return std::make_shared<Graph>(); }
return std::make_shared<Graph>();
}
/*! \brief Create from coo */ /*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO( static MutableGraphPtr CreateFromCOO(
......
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
#ifndef DGL_GRAPH_INTERFACE_H_ #ifndef DGL_GRAPH_INTERFACE_H_
#define DGL_GRAPH_INTERFACE_H_ #define DGL_GRAPH_INTERFACE_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h" #include "./runtime/object.h"
#include "array.h" #include "array.h"
...@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1); ...@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
* \brief This class references data in std::vector. * \brief This class references data in std::vector.
* *
* This isn't a STL-style iterator. It provides a STL data container interface. * This isn't a STL-style iterator. It provides a STL data container interface.
* but it doesn't own data itself. instead, it only references data in std::vector. * but it doesn't own data itself. instead, it only references data in
* std::vector.
*/ */
class DGLIdIters { class DGLIdIters {
public: public:
...@@ -34,18 +35,11 @@ class DGLIdIters { ...@@ -34,18 +35,11 @@ class DGLIdIters {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
} }
const dgl_id_t *begin() const { const dgl_id_t *begin() const { return this->begin_; }
return this->begin_; const dgl_id_t *end() const { return this->end_; }
} dgl_id_t operator[](int64_t i) const { return *(this->begin_ + i); }
const dgl_id_t *end() const { size_t size() const { return this->end_ - this->begin_; }
return this->end_;
}
dgl_id_t operator[](int64_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
private: private:
const dgl_id_t *begin_{nullptr}, *end_{nullptr}; const dgl_id_t *begin_{nullptr}, *end_{nullptr};
}; };
...@@ -63,23 +57,15 @@ class DGLIdIters32 { ...@@ -63,23 +57,15 @@ class DGLIdIters32 {
this->begin_ = begin; this->begin_ = begin;
this->end_ = end; this->end_ = end;
} }
const int32_t *begin() const { const int32_t *begin() const { return this->begin_; }
return this->begin_; const int32_t *end() const { return this->end_; }
} int32_t operator[](int32_t i) const { return *(this->begin_ + i); }
const int32_t *end() const { size_t size() const { return this->end_ - this->begin_; }
return this->end_;
}
int32_t operator[](int32_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
private: private:
const int32_t *begin_{nullptr}, *end_{nullptr}; const int32_t *begin_{nullptr}, *end_{nullptr};
}; };
/* \brief structure used to represent a list of edges */ /* \brief structure used to represent a list of edges */
typedef struct { typedef struct {
/* \brief the two endpoints and the id of the edge */ /* \brief the two endpoints and the id of the edge */
...@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object { ...@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object {
virtual DGLContext Context() const = 0; virtual DGLContext Context() const = 0;
/*! /*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64). * \brief Get the number of integer bits used to store node/edge ids
* (32 or 64).
*/ */
virtual uint8_t NumBits() const = 0; virtual uint8_t NumBits() const = 0;
...@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object { ...@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object {
virtual uint64_t NumEdges() const = 0; virtual uint64_t NumEdges() const = 0;
/*! \return true if the given vertex is in the graph.*/ /*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const { virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
return vid < NumVertices();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(IdArray vids) const = 0; virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/ /*! \return true if the given edge is in the graph.*/
...@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object { ...@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;
...@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object { ...@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0; virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;
...@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object { ...@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0; virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
virtual EdgeArray FindEdges(IdArray eids) const = 0; virtual EdgeArray FindEdges(IdArray eids) const = 0;
...@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object { ...@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and * \note If order is "srcdst", the returned edges list is sorted by their src
* dst ids. If order is "eid", they are in their edge id order. * and dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order. * Otherwise, in the arbitrary order.
* \param order The order of the returned edge list. * \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
...@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object { ...@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then * The induced subgraph is a subgraph formed by specifying a set of vertices
* selecting all of the edges from the original graph that connect two vertices in V'. * V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the vertices preserve the order of the given id array, while the local index * index. The local index of the vertices preserve the order of the given id
* of the edges preserve the index order in the original graph. Vertices not in the * array, while the local index of the edges preserve the index order in the
* original graph are ignored. * original graph. Vertices not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
...@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object { ...@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object {
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then * The induced edges subgraph is a subgraph formed by specifying a set of
* selecting all of the nodes from the original graph that are endpoints in E'. * edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the edges preserve the order of the given id array, while the local index * index. The local index of the edges preserve the order of the given id
* of the vertices preserve the index order in the original graph. Edges not in the * array, while the local index of the vertices preserve the index order in
* original graph are ignored. * the original graph. Edges not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices * \param preserve_nodes If true, the vertices will not be relabeled, so some
* may have no incident edges. * vertices may have no incident edges.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0; virtual Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const = 0;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
...@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object { ...@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing * If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids * indptr, indices and edge ids
* *
* If the fmt is 'coo', the function should return one array of shape (2, nnz), * If the fmt is 'coo', the function should return one array of shape (2,
* representing a horitonzal stack of row and col indices. * nnz), representing a horitonzal stack of row and col indices.
* *
* \param transpose A flag to transpose the returned adjacency matrix. * \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0; virtual std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const = 0;
/*! /*!
* \brief Sort the columns in CSR. * \brief Sort the columns in CSR.
...@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object { ...@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object {
* This sorts the columns in each row based on the column Ids. * This sorts the columns in each row based on the column Ids.
* The edge ids should be sorted accordingly. * The edge ids should be sorted accordingly.
*/ */
virtual void SortCSR() { virtual void SortCSR() {}
}
static constexpr const char* _type_key = "graph.Graph"; static constexpr const char *_type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
}; };
...@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object { ...@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object {
GraphPtr graph; GraphPtr graph;
/*! /*!
* \brief The induced vertex ids. * \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph. * \note This is also a map from the new vertex id to the vertex id in the
* parent graph.
*/ */
IdArray induced_vertices; IdArray induced_vertices;
/*! /*!
* \brief The induced edge ids. * \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph. * \note This is also a map from the new edge id to the edge id in the parent
* graph.
*/ */
IdArray induced_edges; IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph"; static constexpr const char *_type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
}; };
/*! \brief Subgraph data structure for negative subgraph */ /*! \brief Subgraph data structure for negative subgraph */
struct NegSubgraph: public Subgraph { struct NegSubgraph : public Subgraph {
/*! \brief The existence of the negative edges in the parent graph. */ /*! \brief The existence of the negative edges in the parent graph. */
IdArray exist; IdArray exist;
...@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph { ...@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph {
}; };
/*! \brief Subgraph data structure for halo subgraph */ /*! \brief Subgraph data structure for halo subgraph */
struct HaloSubgraph: public Subgraph { struct HaloSubgraph : public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */ /*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes; IdArray inner_nodes;
}; };
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#define DGL_GRAPH_OP_H_ #define DGL_GRAPH_OP_H_
#include <vector> #include <vector>
#include "graph.h" #include "graph.h"
#include "immutable_graph.h" #include "immutable_graph.h"
...@@ -17,7 +18,8 @@ class GraphOp { ...@@ -17,7 +18,8 @@ class GraphOp {
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
* *
* The returned graph preserves the vertex and edge index in the original graph. * The returned graph preserves the vertex and edge index in the original
* graph.
* *
* \return the reversed graph * \return the reversed graph
*/ */
...@@ -40,10 +42,10 @@ class GraphOp { ...@@ -40,10 +42,10 @@ class GraphOp {
* \brief Return a disjoint union of the input graphs. * \brief Return a disjoint union of the input graphs.
* *
* The new graph will include all the nodes/edges in the given graphs. * The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes * Nodes/Edges will be relabled by adding the cumsum of the previous graph
* in the given sequence order. For example, giving input [g1, g2, g3], where * sizes in the given sequence order. For example, giving input [g1, g2, g3],
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7 * where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become
* in the result graph. Edge ids are re-assigned similarly. * node#7 in the result graph. Edge ids are re-assigned similarly.
* *
* The input list must be either ALL mutable graphs or ALL immutable graphs. * The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type. * The returned graph type is also determined by the input graph type.
...@@ -62,12 +64,13 @@ class GraphOp { ...@@ -62,12 +64,13 @@ class GraphOp {
* *
* If the input graph is mutable, the result graphs are mutable. * If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable. * If the input graph is immutable, the result graphs are immutable.
* *
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param num The number of partitions. * \param num The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<GraphPtr> DisjointPartitionByNum(GraphPtr graph, int64_t num); static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num);
/*! /*!
* \brief Partition the graph into several subgraphs. * \brief Partition the graph into several subgraphs.
...@@ -78,21 +81,22 @@ class GraphOp { ...@@ -78,21 +81,22 @@ class GraphOp {
* *
* If the input graph is mutable, the result graphs are mutable. * If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable. * If the input graph is immutable, the result graphs are immutable.
* *
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param sizes The number of partitions. * \param sizes The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<GraphPtr> DisjointPartitionBySizes(GraphPtr graph, IdArray sizes); static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes);
/*! /*!
* \brief Map vids in the parent graph to the vids in the subgraph. * \brief Map vids in the parent graph to the vids in the subgraph.
* *
* If the Id doesn't exist in the subgraph, -1 will be used. * If the Id doesn't exist in the subgraph, -1 will be used.
* *
* \param parent_vid_map An array that maps the vids in the parent graph to the * \param parent_vid_map An array that maps the vids in the parent graph to
* subgraph. The elements store the vertex Ids in the parent graph, and the * the subgraph. The elements store the vertex Ids in the parent graph, and
* indices indicate the vertex Ids in the subgraph. * the indices indicate the vertex Ids in the subgraph.
* \param query The vertex Ids in the parent graph. * \param query The vertex Ids in the parent graph.
* \return an Id array that contains the subgraph node Ids. * \return an Id array that contains the subgraph node Ids.
*/ */
...@@ -134,15 +138,18 @@ class GraphOp { ...@@ -134,15 +138,18 @@ class GraphOp {
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph); static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*! /*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable. * \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable.
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable bidirected graph. * \return a new immutable bidirected
* graph.
*/ */
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph); static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/*! /*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable * \brief Same as BidirectedMutableGraph except that the returned graph is
* and call gk_csr_MakeSymmetric in GKlib. This is more efficient than ToBidirectedImmutableGraph. * immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* It return a null pointer if the conversion fails. * than ToBidirectedImmutableGraph. It return a null pointer if the conversion
* fails.
* *
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable bidirected graph. * \return a new immutable bidirected graph.
...@@ -151,21 +158,25 @@ class GraphOp { ...@@ -151,21 +158,25 @@ class GraphOp {
/*! /*!
* \brief Get a induced subgraph with HALO nodes. * \brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within `num_hops`. * The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`.
* \param graph The input graph. * \param graph The input graph.
* \param nodes The input nodes that form the core of the induced subgraph. * \param nodes The input nodes that form the core of the induced subgraph.
* \param num_hops The number of hops to reach. * \param num_hops The number of hops to reach.
* \return the induced subgraph with HALO nodes. * \return the induced subgraph with HALO nodes.
*/ */
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops); static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops);
/*! /*!
* \brief Reorder the nodes in the immutable graph. * \brief Reorder the nodes in the immutable graph.
* \param graph The input graph. * \param graph The input graph.
* \param new_order The node Ids in the new graph. The index in `new_order` is old node Ids. * \param new_order The node Ids in the new graph. The index in `new_order` is
* old node Ids.
* \return the graph with reordered node Ids * \return the graph with reordered node Ids
*/ */
static GraphPtr ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order); static GraphPtr ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order);
}; };
} // namespace dgl } // namespace dgl
......
...@@ -16,7 +16,8 @@ namespace dgl { ...@@ -16,7 +16,8 @@ namespace dgl {
* \brief Class for representing frontiers. * \brief Class for representing frontiers.
* *
* Each frontier is a list of nodes/edges (specified by their ids). * Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value). * An optional tag can be specified on each node/edge (represented by an int
* value).
*/ */
struct Frontiers { struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */ /*!\brief a vector store for the nodes/edges in all the frontiers */
...@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); ...@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* FORWARD(0), REVERSE(1), NONTREE(2) * FORWARD(0), REVERSE(1), NONTREE(2)
* *
* A FORWARD edge is one in which `u` has been visisted but `v` has not. * A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge * A REVERSE edge is one in which both `u` and `v` have been visisted and the
* is in the DFS tree. * edge is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge * A NONTREE edge is one in which both `u` and `v` have been visisted but the
* is NOT in the DFS tree. * edge is NOT in the DFS tree.
* *
* \param csr The input csr matrix. * \param csr The input csr matrix.
* \param sources Source nodes. * \param sources Source nodes.
...@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source); ...@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* \param return_labels If true, return the recorded edge tags. * \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result * \return A Frontiers object containing the search result
*/ */
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(
IdArray source, const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_reverse_edge, const bool has_nontree_edge, const bool return_labels);
const bool has_nontree_edge,
const bool return_labels);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -6,17 +6,18 @@ ...@@ -6,17 +6,18 @@
#ifndef DGL_IMMUTABLE_GRAPH_H_ #ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_ #define DGL_IMMUTABLE_GRAPH_H_
#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
#include <algorithm> #include <algorithm>
#include <cstdint>
#include <memory> #include <memory>
#include "runtime/ndarray.h" #include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "base_heterograph.h"
#include "graph_interface.h" #include "graph_interface.h"
#include "lazy.h" #include "lazy.h"
#include "base_heterograph.h" #include "runtime/ndarray.h"
namespace dgl { namespace dgl {
...@@ -37,16 +38,16 @@ class CSR : public GraphInterface { ...@@ -37,16 +38,16 @@ class CSR : public GraphInterface {
CSR(int64_t num_vertices, int64_t num_edges); CSR(int64_t num_vertices, int64_t num_edges);
// Create a csr graph whose memory is stored in the shared memory // Create a csr graph whose memory is stored in the shared memory
// that has the given number of verts and edges. // that has the given number of verts and edges.
CSR(const std::string &shared_mem_name, CSR(const std::string &shared_mem_name, int64_t num_vertices,
int64_t num_vertices, int64_t num_edges); int64_t num_edges);
// Create a csr graph that shares the given indptr and indices. // Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids); CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
// Create a csr graph by data iterator // Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter> template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges, CSR(int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin); IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
// Create a csr graph whose memory is stored in the shared memory // Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indcies. // and the structure is given by the indptr and indcies.
...@@ -65,31 +66,19 @@ class CSR : public GraphInterface { ...@@ -65,31 +66,19 @@ class CSR : public GraphInterface {
LOG(FATAL) << "CSR graph does not allow mutation."; LOG(FATAL) << "CSR graph does not allow mutation.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "CSR graph does not allow mutation."; }
LOG(FATAL) << "CSR graph does not allow mutation.";
}
DGLContext Context() const override { DGLContext Context() const override { return adj_.indptr->ctx; }
return adj_.indptr->ctx;
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
return adj_.indices->dtype.bits;
}
bool IsMultigraph() const override; bool IsMultigraph() const override;
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices() const override { uint64_t NumVertices() const override { return adj_.indptr->shape[0] - 1; }
return adj_.indptr->shape[0] - 1;
}
uint64_t NumEdges() const override { uint64_t NumEdges() const override { return adj_.indices->shape[0]; }
return adj_.indices->shape[0];
}
BoolArray HasVertices(IdArray vids) const override { BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for CSR graph"; LOG(FATAL) << "Not enabled for CSR graph";
...@@ -102,7 +91,7 @@ class CSR : public GraphInterface { ...@@ -102,7 +91,7 @@ class CSR : public GraphInterface {
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override { IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "CSR graph does not support efficient predecessor query." LOG(FATAL) << "CSR graph does not support efficient predecessor query."
<< " Please use successors on the reverse CSR graph."; << " Please use successors on the reverse CSR graph.";
return {}; return {};
} }
...@@ -114,25 +103,25 @@ class CSR : public GraphInterface { ...@@ -114,25 +103,25 @@ class CSR : public GraphInterface {
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
LOG(FATAL) << "CSR graph does not support efficient FindEdge." LOG(FATAL) << "CSR graph does not support efficient FindEdge."
<< " Please use COO graph."; << " Please use COO graph.";
return {}; return {};
} }
EdgeArray FindEdges(IdArray eids) const override { EdgeArray FindEdges(IdArray eids) const override {
LOG(FATAL) << "CSR graph does not support efficient FindEdges." LOG(FATAL) << "CSR graph does not support efficient FindEdges."
<< " Please use COO graph."; << " Please use COO graph.";
return {}; return {};
} }
EdgeArray InEdges(dgl_id_t vid) const override { EdgeArray InEdges(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient inedges query." LOG(FATAL) << "CSR graph does not support efficient inedges query."
<< " Please use outedges on the reverse CSR graph."; << " Please use outedges on the reverse CSR graph.";
return {}; return {};
} }
EdgeArray InEdges(IdArray vids) const override { EdgeArray InEdges(IdArray vids) const override {
LOG(FATAL) << "CSR graph does not support efficient inedges query." LOG(FATAL) << "CSR graph does not support efficient inedges query."
<< " Please use outedges on the reverse CSR graph."; << " Please use outedges on the reverse CSR graph.";
return {}; return {};
} }
...@@ -144,13 +133,13 @@ class CSR : public GraphInterface { ...@@ -144,13 +133,13 @@ class CSR : public GraphInterface {
uint64_t InDegree(dgl_id_t vid) const override { uint64_t InDegree(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient indegree query." LOG(FATAL) << "CSR graph does not support efficient indegree query."
<< " Please use outdegree on the reverse CSR graph."; << " Please use outdegree on the reverse CSR graph.";
return 0; return 0;
} }
DegreeArray InDegrees(IdArray vids) const override { DegreeArray InDegrees(IdArray vids) const override {
LOG(FATAL) << "CSR graph does not support efficient indegree query." LOG(FATAL) << "CSR graph does not support efficient indegree query."
<< " Please use outdegree on the reverse CSR graph."; << " Please use outdegree on the reverse CSR graph.";
return {}; return {};
} }
...@@ -162,9 +151,10 @@ class CSR : public GraphInterface { ...@@ -162,9 +151,10 @@ class CSR : public GraphInterface {
Subgraph VertexSubgraph(IdArray vids) const override; Subgraph VertexSubgraph(IdArray vids) const override;
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override { Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override {
LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph." LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
<< " Please use COO graph instead."; << " Please use COO graph instead.";
return {}; return {};
} }
...@@ -174,25 +164,24 @@ class CSR : public GraphInterface { ...@@ -174,25 +164,24 @@ class CSR : public GraphInterface {
DGLIdIters PredVec(dgl_id_t vid) const override { DGLIdIters PredVec(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient PredVec." LOG(FATAL) << "CSR graph does not support efficient PredVec."
<< " Please use SuccVec on the reverse CSR graph."; << " Please use SuccVec on the reverse CSR graph.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
DGLIdIters InEdgeVec(dgl_id_t vid) const override { DGLIdIters InEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient InEdgeVec." LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
<< " Please use OutEdgeVec on the reverse CSR graph."; << " Please use OutEdgeVec on the reverse CSR graph.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override { std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request."; CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {adj_.indptr, adj_.indices, adj_.data}; return {adj_.indptr, adj_.indices, adj_.data};
} }
/*! \brief Indicate whether this uses shared memory. */ /*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const { bool IsSharedMem() const { return !shared_mem_name_.empty(); }
return !shared_mem_name_.empty();
}
/*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */ /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const; CSRPtr Transpose() const;
...@@ -205,16 +194,14 @@ class CSR : public GraphInterface { ...@@ -205,16 +194,14 @@ class CSR : public GraphInterface {
* \note The csr matrix shares the storage with this graph. * \note The csr matrix shares the storage with this graph.
* The data field of the CSR matrix stores the edge ids. * The data field of the CSR matrix stores the edge ids.
*/ */
aten::CSRMatrix ToCSRMatrix() const { aten::CSRMatrix ToCSRMatrix() const { return adj_; }
return adj_;
}
/*! /*!
* \brief Copy the data to another context. * \brief Copy the data to another context.
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
CSR CopyTo(const DGLContext& ctx) const; CSR CopyTo(const DGLContext &ctx) const;
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
...@@ -242,11 +229,10 @@ class CSR : public GraphInterface { ...@@ -242,11 +229,10 @@ class CSR : public GraphInterface {
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! \return Save CSR to stream */ /*! \return Save CSR to stream */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
if (adj_.sorted) if (adj_.sorted) return;
return;
aten::CSRSort_(&adj_); aten::CSRSort_(&adj_);
} }
...@@ -254,7 +240,7 @@ class CSR : public GraphInterface { ...@@ -254,7 +240,7 @@ class CSR : public GraphInterface {
friend class Serializer; friend class Serializer;
/*! \brief private default constructor */ /*! \brief private default constructor */
CSR() {adj_.sorted = false;} CSR() { adj_.sorted = false; }
// The internal CSR adjacency matrix. // The internal CSR adjacency matrix.
// The data field stores edge ids. // The data field stores edge ids.
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
...@@ -267,8 +253,8 @@ class CSR : public GraphInterface { ...@@ -267,8 +253,8 @@ class CSR : public GraphInterface {
class COO : public GraphInterface { class COO : public GraphInterface {
public: public:
// Create a coo graph that shares the given src and dst // Create a coo graph that shares the given src and dst
COO(int64_t num_vertices, IdArray src, IdArray dst, COO(int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted = false,
bool row_sorted = false, bool col_sorted = false); bool col_sorted = false);
// TODO(da): add constructor for creating COO from shared memory // TODO(da): add constructor for creating COO from shared memory
...@@ -284,35 +270,21 @@ class COO : public GraphInterface { ...@@ -284,35 +270,21 @@ class COO : public GraphInterface {
LOG(FATAL) << "COO graph does not allow mutation."; LOG(FATAL) << "COO graph does not allow mutation.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "COO graph does not allow mutation."; }
LOG(FATAL) << "COO graph does not allow mutation.";
}
DGLContext Context() const override { DGLContext Context() const override { return adj_.row->ctx; }
return adj_.row->ctx;
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return adj_.row->dtype.bits; }
return adj_.row->dtype.bits;
}
bool IsMultigraph() const override; bool IsMultigraph() const override;
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices() const override { uint64_t NumVertices() const override { return adj_.num_rows; }
return adj_.num_rows;
}
uint64_t NumEdges() const override { uint64_t NumEdges() const override { return adj_.row->shape[0]; }
return adj_.row->shape[0];
}
bool HasVertex(dgl_id_t vid) const override { bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
return vid < NumVertices();
}
BoolArray HasVertices(IdArray vids) const override { BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for COO graph"; LOG(FATAL) << "Not enabled for COO graph";
...@@ -321,37 +293,37 @@ class COO : public GraphInterface { ...@@ -321,37 +293,37 @@ class COO : public GraphInterface {
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween." LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return false; return false;
} }
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override { BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween." LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override { IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "COO graph does not support efficient Predecessors." LOG(FATAL) << "COO graph does not support efficient Predecessors."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override { IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "COO graph does not support efficient Successors." LOG(FATAL) << "COO graph does not support efficient Successors."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override { IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
LOG(FATAL) << "COO graph does not support efficient EdgeId." LOG(FATAL) << "COO graph does not support efficient EdgeId."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
EdgeArray EdgeIds(IdArray src, IdArray dst) const override { EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
LOG(FATAL) << "COO graph does not support efficient EdgeId." LOG(FATAL) << "COO graph does not support efficient EdgeId."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
...@@ -361,25 +333,25 @@ class COO : public GraphInterface { ...@@ -361,25 +333,25 @@ class COO : public GraphInterface {
EdgeArray InEdges(dgl_id_t vid) const override { EdgeArray InEdges(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InEdges." LOG(FATAL) << "COO graph does not support efficient InEdges."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
EdgeArray InEdges(IdArray vids) const override { EdgeArray InEdges(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient InEdges." LOG(FATAL) << "COO graph does not support efficient InEdges."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
EdgeArray OutEdges(dgl_id_t vid) const override { EdgeArray OutEdges(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdges." LOG(FATAL) << "COO graph does not support efficient OutEdges."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
EdgeArray OutEdges(IdArray vids) const override { EdgeArray OutEdges(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdges." LOG(FATAL) << "COO graph does not support efficient OutEdges."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
...@@ -387,61 +359,63 @@ class COO : public GraphInterface { ...@@ -387,61 +359,63 @@ class COO : public GraphInterface {
uint64_t InDegree(dgl_id_t vid) const override { uint64_t InDegree(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InDegree." LOG(FATAL) << "COO graph does not support efficient InDegree."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return 0; return 0;
} }
DegreeArray InDegrees(IdArray vids) const override { DegreeArray InDegrees(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient InDegrees." LOG(FATAL) << "COO graph does not support efficient InDegrees."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
uint64_t OutDegree(dgl_id_t vid) const override { uint64_t OutDegree(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutDegree." LOG(FATAL) << "COO graph does not support efficient OutDegree."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return 0; return 0;
} }
DegreeArray OutDegrees(IdArray vids) const override { DegreeArray OutDegrees(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient OutDegrees." LOG(FATAL) << "COO graph does not support efficient OutDegrees."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
Subgraph VertexSubgraph(IdArray vids) const override { Subgraph VertexSubgraph(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient VertexSubgraph." LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return {}; return {};
} }
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
DGLIdIters SuccVec(dgl_id_t vid) const override { DGLIdIters SuccVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient SuccVec." LOG(FATAL) << "COO graph does not support efficient SuccVec."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
DGLIdIters OutEdgeVec(dgl_id_t vid) const override { DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdgeVec." LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
DGLIdIters PredVec(dgl_id_t vid) const override { DGLIdIters PredVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient PredVec." LOG(FATAL) << "COO graph does not support efficient PredVec."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
DGLIdIters InEdgeVec(dgl_id_t vid) const override { DGLIdIters InEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InEdgeVec." LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override { std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request."; CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) { if (transpose) {
return {aten::HStack(adj_.col, adj_.row)}; return {aten::HStack(adj_.col, adj_.row)};
...@@ -463,16 +437,14 @@ class COO : public GraphInterface { ...@@ -463,16 +437,14 @@ class COO : public GraphInterface {
* \note The coo matrix shares the storage with this graph. * \note The coo matrix shares the storage with this graph.
* The data field of the coo matrix is none. * The data field of the coo matrix is none.
*/ */
aten::COOMatrix ToCOOMatrix() const { aten::COOMatrix ToCOOMatrix() const { return adj_; }
return adj_;
}
/*! /*!
* \brief Copy the data to another context. * \brief Copy the data to another context.
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
COO CopyTo(const DGLContext& ctx) const; COO CopyTo(const DGLContext &ctx) const;
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
...@@ -489,9 +461,7 @@ class COO : public GraphInterface { ...@@ -489,9 +461,7 @@ class COO : public GraphInterface {
COO AsNumBits(uint8_t bits) const; COO AsNumBits(uint8_t bits) const;
/*! \brief Indicate whether this uses shared memory. */ /*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const { bool IsSharedMem() const { return false; }
return false;
}
// member getters // member getters
...@@ -513,40 +483,40 @@ class COO : public GraphInterface { ...@@ -513,40 +483,40 @@ class COO : public GraphInterface {
* *
* DGL's graph is directed. Vertices are integers enumerated from zero. * DGL's graph is directed. Vertices are integers enumerated from zero.
*/ */
class ImmutableGraph: public GraphInterface { class ImmutableGraph : public GraphInterface {
public: public:
/*! \brief Construct an immutable graph from the COO format. */ /*! \brief Construct an immutable graph from the COO format. */
explicit ImmutableGraph(COOPtr coo): coo_(coo) { } explicit ImmutableGraph(COOPtr coo) : coo_(coo) {}
/*! /*!
* \brief Construct an immutable graph from the CSR format. * \brief Construct an immutable graph from the CSR format.
* *
* For a single graph, we need two CSRs, one stores the in-edges of vertices and * For a single graph, we need two CSRs, one stores the in-edges of vertices
* the other stores the out-edges of vertices. These two CSRs stores the same edges. * and the other stores the out-edges of vertices. These two CSRs stores the
* The reason we need both is that some operators are faster on in-edge CSR and * same edges. The reason we need both is that some operators are faster on
* the other operators are faster on out-edge CSR. * in-edge CSR and the other operators are faster on out-edge CSR.
* *
* However, not both CSRs are required. Technically, one CSR contains all information. * However, not both CSRs are required. Technically, one CSR contains all
* Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only * information. Thus, when we construct a temporary graphs (e.g., the sampled
* construct one of the CSRs that runs fast for some operations we expect and construct * subgraphs), we only construct one of the CSRs that runs fast for some
* the other CSR on demand. * operations we expect and construct the other CSR on demand.
*/ */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr) ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
: in_csr_(in_csr), out_csr_(out_csr) { : in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing."; CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
} }
/*! \brief Construct an immutable graph from one CSR. */ /*! \brief Construct an immutable graph from one CSR. */
explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) { } explicit ImmutableGraph(CSRPtr csr) : out_csr_(csr) {}
/*! \brief default copy constructor */ /*! \brief default copy constructor */
ImmutableGraph(const ImmutableGraph& other) = default; ImmutableGraph(const ImmutableGraph &other) = default;
#ifndef _MSC_VER #ifndef _MSC_VER
/*! \brief default move constructor */ /*! \brief default move constructor */
ImmutableGraph(ImmutableGraph&& other) = default; ImmutableGraph(ImmutableGraph &&other) = default;
#else #else
ImmutableGraph(ImmutableGraph&& other) { ImmutableGraph(ImmutableGraph &&other) {
this->in_csr_ = other.in_csr_; this->in_csr_ = other.in_csr_;
this->out_csr_ = other.out_csr_; this->out_csr_ = other.out_csr_;
this->coo_ = other.coo_; this->coo_ = other.coo_;
...@@ -557,7 +527,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -557,7 +527,7 @@ class ImmutableGraph: public GraphInterface {
#endif // _MSC_VER #endif // _MSC_VER
/*! \brief default assign constructor */ /*! \brief default assign constructor */
ImmutableGraph& operator=(const ImmutableGraph& other) = default; ImmutableGraph &operator=(const ImmutableGraph &other) = default;
/*! \brief default destructor */ /*! \brief default destructor */
~ImmutableGraph() = default; ~ImmutableGraph() = default;
...@@ -578,28 +548,20 @@ class ImmutableGraph: public GraphInterface { ...@@ -578,28 +548,20 @@ class ImmutableGraph: public GraphInterface {
LOG(FATAL) << "Clear isn't supported in ImmutableGraph"; LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
} }
DGLContext Context() const override { DGLContext Context() const override { return AnyGraph()->Context(); }
return AnyGraph()->Context();
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return AnyGraph()->NumBits(); }
return AnyGraph()->NumBits();
}
/*! /*!
* \note not const since we have caches * \note not const since we have caches
* \return whether the graph is a multigraph * \return whether the graph is a multigraph
*/ */
bool IsMultigraph() const override { bool IsMultigraph() const override { return AnyGraph()->IsMultigraph(); }
return AnyGraph()->IsMultigraph();
}
/*! /*!
* \return whether the graph is read-only * \return whether the graph is read-only
*/ */
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
/** /**
* \brief Check if the graph is unibipartite. * \brief Check if the graph is unibipartite.
...@@ -616,19 +578,13 @@ class ImmutableGraph: public GraphInterface { ...@@ -616,19 +578,13 @@ class ImmutableGraph: public GraphInterface {
} }
/*! \return the number of vertices in the graph.*/ /*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override { uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); }
return AnyGraph()->NumVertices();
}
/*! \return the number of edges in the graph.*/ /*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override { uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); }
return AnyGraph()->NumEdges();
}
/*! \return true if the given vertex is in the graph.*/ /*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const override { bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
return vid < NumVertices();
}
BoolArray HasVertices(IdArray vids) const override; BoolArray HasVertices(IdArray vids) const override;
...@@ -652,7 +608,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -652,7 +608,8 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Find the predecessors of a vertex. * \brief Find the predecessors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array. * \return the predecessor id array.
*/ */
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override { IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
...@@ -662,7 +619,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -662,7 +619,8 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Find the successors of a vertex. * \brief Find the successors of a vertex.
* \param vid The vertex id. * \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1). * \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array. * \return the successor id array.
*/ */
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override { IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
...@@ -706,7 +664,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -706,7 +664,8 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID * \param eid The edge ID
* \return a pair whose first element is the source and the second the destination. * \return a pair whose first element is the source and the second the
* destination.
*/ */
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return GetCOO()->FindEdge(eid); return GetCOO()->FindEdge(eid);
...@@ -715,7 +674,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -715,7 +674,8 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Find the edge IDs and return their source and target node IDs. * \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array. * \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved. * \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/ */
EdgeArray FindEdges(IdArray eids) const override { EdgeArray FindEdges(IdArray eids) const override {
return GetCOO()->FindEdges(eids); return GetCOO()->FindEdges(eids);
...@@ -728,7 +688,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -728,7 +688,7 @@ class ImmutableGraph: public GraphInterface {
* \return the edges * \return the edges
*/ */
EdgeArray InEdges(dgl_id_t vid) const override { EdgeArray InEdges(dgl_id_t vid) const override {
const EdgeArray& ret = GetInCSR()->OutEdges(vid); const EdgeArray &ret = GetInCSR()->OutEdges(vid);
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} }
...@@ -738,7 +698,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -738,7 +698,7 @@ class ImmutableGraph: public GraphInterface {
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray InEdges(IdArray vids) const override { EdgeArray InEdges(IdArray vids) const override {
const EdgeArray& ret = GetInCSR()->OutEdges(vids); const EdgeArray &ret = GetInCSR()->OutEdges(vids);
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} }
...@@ -765,7 +725,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -765,7 +725,8 @@ class ImmutableGraph: public GraphInterface {
* \brief Get all the edges in the graph. * \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and * \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order. * dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids * \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges. * \return the id arrays of the two endpoints of the edges.
*/ */
EdgeArray Edges(const std::string &order = "") const override; EdgeArray Edges(const std::string &order = "") const override;
...@@ -809,13 +770,14 @@ class ImmutableGraph: public GraphInterface { ...@@ -809,13 +770,14 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Construct the induced subgraph of the given vertices. * \brief Construct the induced subgraph of the given vertices.
* *
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then * The induced subgraph is a subgraph formed by specifying a set of vertices
* selecting all of the edges from the original graph that connect two vertices in V'. * V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the vertices preserve the order of the given id array, while the local index * index. The local index of the vertices preserve the order of the given id
* of the edges preserve the index order in the original graph. Vertices not in the * array, while the local index of the edges preserve the index order in the
* original graph are ignored. * original graph. Vertices not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
...@@ -827,20 +789,22 @@ class ImmutableGraph: public GraphInterface { ...@@ -827,20 +789,22 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Construct the induced edge subgraph of the given edges. * \brief Construct the induced edge subgraph of the given edges.
* *
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then * The induced edges subgraph is a subgraph formed by specifying a set of
* selecting all of the nodes from the original graph that are endpoints in E'. * edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
* *
* Vertices and edges in the original graph will be "reindexed" to local index. The local * Vertices and edges in the original graph will be "reindexed" to local
* index of the edges preserve the order of the given id array, while the local index * index. The local index of the edges preserve the order of the given id
* of the vertices preserve the index order in the original graph. Edges not in the * array, while the local index of the vertices preserve the index order in
* original graph are ignored. * the original graph. Edges not in the original graph are ignored.
* *
* The result subgraph is read-only. * The result subgraph is read-only.
* *
* \param eids The edges in the subgraph. * \param eids The edges in the subgraph.
* \return the induced edge subgraph * \return the induced edge subgraph
*/ */
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
...@@ -887,7 +851,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -887,7 +851,8 @@ class ImmutableGraph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray. * \return a vector of three IdArray.
*/ */
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override; std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override;
/* !\brief Return in csr. If not exist, transpose the other one.*/ /* !\brief Return in csr. If not exist, transpose the other one.*/
CSRPtr GetInCSR() const; CSRPtr GetInCSR() const;
...@@ -900,14 +865,15 @@ class ImmutableGraph: public GraphInterface { ...@@ -900,14 +865,15 @@ class ImmutableGraph: public GraphInterface {
/*! \brief Create an immutable graph from CSR. */ /*! \brief Create an immutable graph from CSR. */
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir); IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name); static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);
/*! \brief Create an immutable graph from COO. */ /*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO( static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, int64_t num_vertices, IdArray src, IdArray dst, bool row_osrted = false,
bool row_osrted = false, bool col_sorted = false); bool col_sorted = false);
/*! /*!
* \brief Convert the given graph to an immutable graph. * \brief Convert the given graph to an immutable graph.
...@@ -925,14 +891,15 @@ class ImmutableGraph: public GraphInterface { ...@@ -925,14 +891,15 @@ class ImmutableGraph: public GraphInterface {
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx); static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext &ctx);
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
* \param name The name of the shared memory. * \param name The name of the shared memory.
* \return The graph in the shared memory * \return The graph in the shared memory
*/ */
static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name); static ImmutableGraphPtr CopyToSharedMem(
ImmutableGraphPtr g, const std::string &name);
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
...@@ -944,7 +911,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -944,7 +911,8 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Return a new graph with all the edges reversed. * \brief Return a new graph with all the edges reversed.
* *
* The returned graph preserves the vertex and edge index in the original graph. * The returned graph preserves the vertex and edge index in the original
* graph.
* *
* \return the reversed graph * \return the reversed graph
*/ */
...@@ -954,20 +922,16 @@ class ImmutableGraph: public GraphInterface { ...@@ -954,20 +922,16 @@ class ImmutableGraph: public GraphInterface {
bool Load(dmlc::Stream *fs); bool Load(dmlc::Stream *fs);
/*! \return Save ImmutableGraph to stream, using out csr */ /*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream *fs) const;
void SortCSR() override { void SortCSR() override {
GetInCSR()->SortCSR(); GetInCSR()->SortCSR();
GetOutCSR()->SortCSR(); GetOutCSR()->SortCSR();
} }
bool HasInCSR() const { bool HasInCSR() const { return in_csr_ != NULL; }
return in_csr_ != NULL;
}
bool HasOutCSR() const { bool HasOutCSR() const { return out_csr_ != NULL; }
return out_csr_ != NULL;
}
/*! \brief Cast this graph to a heterograph */ /*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const; HeteroGraphPtr AsHeteroGraph() const;
...@@ -981,12 +945,13 @@ class ImmutableGraph: public GraphInterface { ...@@ -981,12 +945,13 @@ class ImmutableGraph: public GraphInterface {
/* !\brief internal constructor for all the members */ /* !\brief internal constructor for all the members */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo) ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { : in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(AnyGraph()) << "At least one graph structure should exist."; CHECK(AnyGraph()) << "At least one graph structure should exist.";
} }
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name) ImmutableGraph(
: in_csr_(in_csr), out_csr_(out_csr) { CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing."; CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
this->shared_mem_name_ = shared_mem_name; this->shared_mem_name_ = shared_mem_name;
} }
...@@ -1025,18 +990,19 @@ class ImmutableGraph: public GraphInterface { ...@@ -1025,18 +990,19 @@ class ImmutableGraph: public GraphInterface {
// inline implementations // inline implementations
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter> template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges, CSR::CSR(
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin) { int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,
IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
// TODO(minjie): this should be changed to a device-agnostic implementation // TODO(minjie): this should be changed to a device-agnostic implementation
// in the future // in the future.
adj_.num_rows = num_vertices; adj_.num_rows = num_vertices;
adj_.num_cols = num_vertices; adj_.num_cols = num_vertices;
adj_.indptr = aten::NewIdArray(num_vertices + 1); adj_.indptr = aten::NewIdArray(num_vertices + 1);
adj_.indices = aten::NewIdArray(num_edges); adj_.indices = aten::NewIdArray(num_edges);
adj_.data = aten::NewIdArray(num_edges); adj_.data = aten::NewIdArray(num_edges);
dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data); dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data); dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);
dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(adj_.data->data); dgl_id_t *edge_ids_data = static_cast<dgl_id_t *>(adj_.data->data);
for (int64_t i = 0; i < num_vertices + 1; ++i) for (int64_t i = 0; i < num_vertices + 1; ++i)
*(indptr_data++) = *(indptr_begin++); *(indptr_data++) = *(indptr_begin++);
for (int64_t i = 0; i < num_edges; ++i) { for (int64_t i = 0; i < num_edges; ++i) {
......
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#define DGL_KERNEL_H_ #define DGL_KERNEL_H_
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "array.h"
#include "./bcast.h"
#include "./base_heterograph.h" #include "./base_heterograph.h"
#include "./bcast.h"
#include "array.h"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -30,12 +30,9 @@ namespace aten { ...@@ -30,12 +30,9 @@ namespace aten {
* as the argmax on source nodes and edges for reduce operators such as * as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`. * `min` and `max`.
*/ */
void SpMM(const std::string& op, const std::string& reduce, void SpMM(
HeteroGraphPtr graph, const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
/*! /*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication. * \brief Generalized Sampled Dense-Dense Matrix Multiplication.
...@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce, ...@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce,
* \param vfeat The destination node feature. * \param vfeat The destination node feature.
* \param out The output feature on edge. * \param out The output feature on edge.
*/ */
void SDDMM(const std::string& op, void SDDMM(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray ufeat, NDArray out);
NDArray efeat,
NDArray out);
/*! /*!
* \brief Sparse-sparse matrix multiplication. * \brief Sparse-sparse matrix multiplication.
* *
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a B_weights * The sparse matrices must have scalar weights (i.e. \a A_weights and \a
* are 1D vectors.) * B_weights are 1D vectors.)
*/ */
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A, CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);
/*! /*!
* \brief Summing up a list of sparse matrices. * \brief Summing up a list of sparse matrices.
...@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM(
* are 1D vectors.) * are 1D vectors.)
*/ */
std::pair<CSRMatrix, NDArray> CSRSum( std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A, const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
const std::vector<NDArray>& A_weights);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -22,15 +22,17 @@ class Lazy { ...@@ -22,15 +22,17 @@ class Lazy {
/*!\brief default constructor to construct a lazy object */ /*!\brief default constructor to construct a lazy object */
Lazy() {} Lazy() {}
/*!\brief constructor to construct an object with given value (non-lazy case) */ /*!
explicit Lazy(const T& val): ptr_(new T(val)) {} * \brief constructor to construct an object with given value (non-lazy case)
*/
explicit Lazy(const T& val) : ptr_(new T(val)) {}
/*!\brief destructor */ /*!\brief destructor */
~Lazy() = default; ~Lazy() = default;
/*! /*!
* \brief Get the value of this object. If the object has not been instantiated, * \brief Get the value of this object. If the object has not been
* using the provided function to create it. * instantiated, using the provided function to create it.
* \param fn The creator function. * \param fn The creator function.
* \return the object value. * \return the object value.
*/ */
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#ifndef DGL_NODEFLOW_H_ #ifndef DGL_NODEFLOW_H_
#define DGL_NODEFLOW_H_ #define DGL_NODEFLOW_H_
#include <vector>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "./runtime/object.h" #include "./runtime/object.h"
#include "graph_interface.h" #include "graph_interface.h"
...@@ -18,12 +18,12 @@ namespace dgl { ...@@ -18,12 +18,12 @@ namespace dgl {
class ImmutableGraph; class ImmutableGraph;
/*! /*!
* \brief A NodeFlow graph stores the sampling results for a sampler that samples * \brief A NodeFlow graph stores the sampling results for a sampler that
* nodes/edges in layers. * samples nodes/edges in layers.
* *
* We store multiple layers of the sampling results in a single graph, which results * We store multiple layers of the sampling results in a single graph, which
* in a more compact format. We store extra information, * results in a more compact format. We store extra information, such as the
* such as the node and edge mapping from the NodeFlow graph to the parent graph. * node and edge mapping from the NodeFlow graph to the parent graph.
*/ */
struct NodeFlowObject : public runtime::Object { struct NodeFlowObject : public runtime::Object {
/*! \brief The graph. */ /*! \brief The graph. */
...@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object { ...@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object {
*/ */
IdArray edge_mapping; IdArray edge_mapping;
static constexpr const char* _type_key = "graph.NodeFlow"; static constexpr const char *_type_key = "graph.NodeFlow";
DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);
}; };
...@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef { ...@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef {
* of an edge and the column represents the source. * of an edge and the column represents the source.
* *
* If fmt == "csr", the function returns three arrays: indptr, indices, eid. * If fmt == "csr", the function returns three arrays: indptr, indices, eid.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx array * If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx
* is the concatenation of src and dst node id arrays. * array is the concatenation of src and dst node id arrays.
* *
* \param graph An immutable graph. * \param graph An immutable graph.
* \param fmt the format of the returned adjacency matrix. * \param fmt the format of the returned adjacency matrix.
...@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef { ...@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef {
* space. * space.
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt, std::vector<IdArray> GetNodeFlowSlice(
size_t layer0_size, size_t layer1_start, const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_end, bool remap); size_t layer1_start, size_t layer1_end, bool remap);
} // namespace dgl } // namespace dgl
#endif // DGL_NODEFLOW_H_ #endif // DGL_NODEFLOW_H_
...@@ -7,14 +7,14 @@ ...@@ -7,14 +7,14 @@
#ifndef DGL_PACKED_FUNC_EXT_H_ #ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_ #define DGL_PACKED_FUNC_EXT_H_
#include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <memory>
#include <type_traits> #include <type_traits>
#include "./runtime/packed_func.h"
#include "./runtime/object.h"
#include "./runtime/container.h" #include "./runtime/container.h"
#include "./runtime/object.h"
#include "./runtime/packed_func.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
...@@ -22,7 +22,7 @@ namespace runtime { ...@@ -22,7 +22,7 @@ namespace runtime {
* \brief Runtime type checker for node type. * \brief Runtime type checker for node type.
* \tparam T the type to be checked. * \tparam T the type to be checked.
*/ */
template<typename T> template <typename T>
struct ObjectTypeChecker { struct ObjectTypeChecker {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
// This is the only place in the project where RTTI is used // This is the only place in the project where RTTI is used
...@@ -31,13 +31,13 @@ struct ObjectTypeChecker { ...@@ -31,13 +31,13 @@ struct ObjectTypeChecker {
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
return sptr->derived_from<ContainerType>(); return sptr->derived_from<ContainerType>();
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType; using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key; os << ContainerType::_type_key;
} }
}; };
template<typename T> template <typename T>
struct ObjectTypeChecker<List<T> > { struct ObjectTypeChecker<List<T> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -48,14 +48,14 @@ struct ObjectTypeChecker<List<T> > { ...@@ -48,14 +48,14 @@ struct ObjectTypeChecker<List<T> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "list<"; os << "list<";
ObjectTypeChecker<T>::PrintName(os); ObjectTypeChecker<T>::PrintName(os);
os << ">"; os << ">";
} }
}; };
template<typename V> template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > { struct ObjectTypeChecker<Map<std::string, V> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -66,7 +66,7 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -66,7 +66,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string"; os << "map<string";
os << ','; os << ',';
ObjectTypeChecker<V>::PrintName(os); ObjectTypeChecker<V>::PrintName(os);
...@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > { ...@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
} }
}; };
template <typename K, typename V>
template<typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > { struct ObjectTypeChecker<Map<K, V> > {
static inline bool Check(Object* sptr) { static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false; if (sptr == nullptr) return false;
...@@ -88,7 +86,7 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -88,7 +86,7 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
return true; return true;
} }
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<"; os << "map<";
ObjectTypeChecker<K>::PrintName(os); ObjectTypeChecker<K>::PrintName(os);
os << ','; os << ',';
...@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > {
} }
}; };
template<typename T> template <typename T>
inline std::string NodeTypeName() { inline std::string NodeTypeName() {
std::ostringstream os; std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os); ObjectTypeChecker<T>::PrintName(os);
...@@ -106,7 +104,7 @@ inline std::string NodeTypeName() { ...@@ -106,7 +104,7 @@ inline std::string NodeTypeName() {
// extensions for DGLArgValue // extensions for DGLArgValue
template<typename TObjectRef> template <typename TObjectRef>
inline TObjectRef DGLArgValue::AsObjectRef() const { inline TObjectRef DGLArgValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
...@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const { ...@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle); DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >(); std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get())) CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TObjectRef>() << "Expected type " << NodeTypeName<TObjectRef>() << " but get "
<< " but get " << sptr->type_key(); << sptr->type_key();
return TObjectRef(sptr); return TObjectRef(sptr);
} }
...@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() { ...@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {
return *ptr<std::shared_ptr<Object> >(); return *ptr<std::shared_ptr<Object> >();
} }
template <typename TObjectRef, typename>
template<typename TObjectRef, typename>
inline bool DGLArgValue::IsObjectType() const { inline bool DGLArgValue::IsObjectType() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle); DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
*ptr<std::shared_ptr<Object> >();
return ObjectTypeChecker<TObjectRef>::Check(sptr.get()); return ObjectTypeChecker<TObjectRef>::Check(sptr.get());
} }
...@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) { ...@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {
return *this; return *this;
} }
template<typename TObjectRef> template <typename TObjectRef>
inline TObjectRef DGLRetValue::AsObjectRef() const { inline TObjectRef DGLRetValue::AsObjectRef() const {
static_assert( static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value, std::is_base_of<ObjectRef, TObjectRef>::value,
...@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const { ...@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const {
return TObjectRef(*ptr<std::shared_ptr<Object> >()); return TObjectRef(*ptr<std::shared_ptr<Object> >());
} }
inline void DGLArgsSetter::operator()(size_t i, const ObjectRef& other) const { // NOLINT(*) inline void DGLArgsSetter::operator()(
size_t i, const ObjectRef& other) const { // NOLINT(*)
if (other.defined()) { if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_)); values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));
type_codes_[i] = kObjectHandle; type_codes_[i] = kObjectHandle;
......
...@@ -8,8 +8,9 @@ ...@@ -8,8 +8,9 @@
#define DGL_RANDOM_H_ #define DGL_RANDOM_H_
#include <dgl/array.h> #include <dgl/array.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <random> #include <random>
#include <thread> #include <thread>
#include <vector> #include <vector>
...@@ -46,33 +47,27 @@ class RandomEngine { ...@@ -46,33 +47,27 @@ class RandomEngine {
} }
/*! \brief Constructor with given seed */ /*! \brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) { explicit RandomEngine(uint32_t seed) { SetSeed(seed); }
SetSeed(seed);
}
/*! \brief Get the thread-local random number generator instance */ /*! \brief Get the thread-local random number generator instance */
static RandomEngine *ThreadLocal() { static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get(); return dmlc::ThreadLocalStore<RandomEngine>::Get();
} }
/*! /*!
* \brief Set the seed of this random number generator * \brief Set the seed of this random number generator
*/ */
void SetSeed(uint32_t seed) { void SetSeed(uint32_t seed) { rng_.seed(seed + GetThreadId()); }
rng_.seed(seed + GetThreadId());
}
/*! /*!
* \brief Generate an arbitrary random 32-bit integer. * \brief Generate an arbitrary random 32-bit integer.
*/ */
int32_t RandInt32() { int32_t RandInt32() { return static_cast<int32_t>(rng_()); }
return static_cast<int32_t>(rng_());
}
/*! /*!
* \brief Generate a uniform random integer in [0, upper) * \brief Generate a uniform random integer in [0, upper)
*/ */
template<typename T> template <typename T>
T RandInt(T upper) { T RandInt(T upper) {
return RandInt<T>(0, upper); return RandInt<T>(0, upper);
} }
...@@ -80,7 +75,7 @@ class RandomEngine { ...@@ -80,7 +75,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random integer in [lower, upper) * \brief Generate a uniform random integer in [lower, upper)
*/ */
template<typename T> template <typename T>
T RandInt(T lower, T upper) { T RandInt(T lower, T upper) {
CHECK_LT(lower, upper); CHECK_LT(lower, upper);
std::uniform_int_distribution<T> dist(lower, upper - 1); std::uniform_int_distribution<T> dist(lower, upper - 1);
...@@ -90,7 +85,7 @@ class RandomEngine { ...@@ -90,7 +85,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random float in [0, 1) * \brief Generate a uniform random float in [0, 1)
*/ */
template<typename T> template <typename T>
T Uniform() { T Uniform() {
return Uniform<T>(0., 1.); return Uniform<T>(0., 1.);
} }
...@@ -98,7 +93,7 @@ class RandomEngine { ...@@ -98,7 +93,7 @@ class RandomEngine {
/*! /*!
* \brief Generate a uniform random float in [lower, upper) * \brief Generate a uniform random float in [lower, upper)
*/ */
template<typename T> template <typename T>
T Uniform(T lower, T upper) { T Uniform(T lower, T upper) {
// Although the result is in [lower, upper), we allow lower == upper as in // Although the result is in [lower, upper), we allow lower == upper as in
// www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/ // www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/
...@@ -108,23 +103,27 @@ class RandomEngine { ...@@ -108,23 +103,27 @@ class RandomEngine {
} }
/*! /*!
* \brief Pick a random integer between 0 to N-1 according to given probabilities * \brief Pick a random integer between 0 to N-1 according to given
* \tparam IdxType Return integer type * probabilities.
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \tparam IdxType Return integer type.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \return An integer randomly picked from 0 to N-1. * \return An integer randomly picked from 0 to N-1.
*/ */
template<typename IdxType> template <typename IdxType>
IdxType Choice(FloatArray prob); IdxType Choice(FloatArray prob);
/*! /*!
* \brief Pick random integers between 0 to N-1 according to given probabilities * \brief Pick random integers between 0 to N-1 according to given
* * probabilities
*
* If replace is false, the number of picked integers must not larger than N. * If replace is false, the number of picked integers must not larger than N.
* *
* \tparam IdxType Id type * \tparam IdxType Id type
* \tparam FloatType Probability value type * \tparam FloatType Probability value type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param out The output buffer to write selected indices. * \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
...@@ -132,14 +131,16 @@ class RandomEngine { ...@@ -132,14 +131,16 @@ class RandomEngine {
void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true); void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true);
/*! /*!
* \brief Pick random integers between 0 to N-1 according to given probabilities * \brief Pick random integers between 0 to N-1 according to given
* * probabilities
*
* If replace is false, the number of picked integers must not larger than N. * If replace is false, the number of picked integers must not larger than N.
* *
* \tparam IdxType Id type * \tparam IdxType Id type
* \tparam FloatType Probability value type * \tparam FloatType Probability value type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative. * \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
* \return Picked indices * \return Picked indices
*/ */
...@@ -147,7 +148,8 @@ class RandomEngine { ...@@ -147,7 +148,8 @@ class RandomEngine {
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) { IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx); IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace); Choice<IdxType, FloatType>(
num, prob, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
...@@ -163,7 +165,8 @@ class RandomEngine { ...@@ -163,7 +165,8 @@ class RandomEngine {
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
template <typename IdxType> template <typename IdxType>
void UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace = true); void UniformChoice(
IdxType num, IdxType population, IdxType* out, bool replace = true);
/*! /*!
* \brief Pick random integers from population by uniform distribution. * \brief Pick random integers from population by uniform distribution.
...@@ -181,43 +184,48 @@ class RandomEngine { ...@@ -181,43 +184,48 @@ class RandomEngine {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now // TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace); UniformChoice<IdxType>(
num, population, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
/*! /*!
* \brief Pick random integers with different probability for different segments. * \brief Pick random integers with different probability for different
* segments.
* *
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some integers * For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some
* from 0 to 9, which is divided into two segments. 0-3 are in the first segment and the rest * integers from 0 to 9, which is divided into two segments. 0-3 are in the
* belongs to the second. The weight(bias) of each candidate in the first segment is upweighted * first segment and the rest belongs to the second. The weight(bias) of each
* to 1.5. * candidate in the first segment is upweighted to 1.5.
* *
* candidate | 0 1 2 3 | 4 5 6 7 8 9 | * candidate | 0 1 2 3 | 4 5 6 7 8 9 |
* split ^ ^ ^ * split ^ ^ ^
* bias | 1.5 | 1 | * bias | 1.5 | 1 |
* *
* *
* The complexity of this operator is O(k * log(T)) where k is the number of integers we want * The complexity of this operator is O(k * log(T)) where k is the number of
* to pick, and T is the number of segments. It is much faster compared with assigning * integers we want to pick, and T is the number of segments. It is much
* probability for each candidate, of which the complexity is O(k * log(N)) where N is the * faster compared with assigning probability for each candidate, of which the
* number of all candidates. * complexity is O(k * log(N)) where N is the number of all candidates.
* *
* If replace is false, num must not be larger than population. * If replace is false, num must not be larger than population.
* *
* \tparam IdxType Return integer type * \tparam IdxType Return integer type
* \param num Number of integers to choose * \param num Number of integers to choose
* \param split Array of T+1 split positions of different segments(including start and end) * \param split Array of T+1 split positions of different segments(including
* \param bias Array of T weight of each segments * start and end)
* \param bias Array of T weight of each segments.
* \param out The output buffer to write selected indices. * \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement. * \param replace If true, choose with replacement.
*/ */
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
void BiasedChoice( void BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace = true); IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace = true);
/*! /*!
* \brief Pick random integers with different probability for different segments. * \brief Pick random integers with different probability for different
* segments.
* *
* If replace is false, num must not be larger than population. * If replace is false, num must not be larger than population.
* *
...@@ -229,10 +237,11 @@ class RandomEngine { ...@@ -229,10 +237,11 @@ class RandomEngine {
*/ */
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
IdArray BiasedChoice( IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) { IdxType num, const IdxType* split, FloatArray bias, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1}; const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0}); IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace); BiasedChoice<IdxType, FloatType>(
num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret; return ret;
} }
......
...@@ -27,9 +27,8 @@ extern "C" { ...@@ -27,9 +27,8 @@ extern "C" {
* \param out The result function. * \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendGetFuncFromEnv(void* mod_node, DGL_DLL int DGLBackendGetFuncFromEnv(
const char* func_name, void* mod_node, const char* func_name, DGLFunctionHandle* out);
DGLFunctionHandle *out);
/*! /*!
* \brief Backend function to register system-wide library symbol. * \brief Backend function to register system-wide library symbol.
* *
...@@ -42,22 +41,21 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr); ...@@ -42,22 +41,21 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*! /*!
* \brief Backend function to allocate temporal workspace. * \brief Backend function to allocate temporal workspace.
* *
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. * \note The result allocate spaced is ensured to be aligned to
* kTempAllocaAlignment.
* *
* \param nbytes The size of the space requested. * \param nbytes The size of the space requested.
* \param device_type The device type which the space will be allocated. * \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated. * \param device_id The device id which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in * \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL. * certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in * \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL. * certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success * \return nullptr when error is thrown, a valid ptr if success
*/ */
DGL_DLL void* DGLBackendAllocWorkspace(int device_type, DGL_DLL void* DGLBackendAllocWorkspace(
int device_id, int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
uint64_t nbytes, int dtype_bits_hint);
int dtype_code_hint,
int dtype_bits_hint);
/*! /*!
* \brief Backend function to free temporal workspace. * \brief Backend function to free temporal workspace.
...@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type, ...@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
* *
* \sa DGLBackendAllocWorkspace * \sa DGLBackendAllocWorkspace
*/ */
DGL_DLL int DGLBackendFreeWorkspace(int device_type, DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);
int device_id,
void* ptr);
/*! /*!
* \brief Environment for DGL parallel task. * \brief Environment for DGL parallel task.
...@@ -100,13 +96,12 @@ typedef int (*FDGLParallelLambda)( ...@@ -100,13 +96,12 @@ typedef int (*FDGLParallelLambda)(
* \param flambda The parallel function to be launched. * \param flambda The parallel function to be launched.
* \param cdata The closure data. * \param cdata The closure data.
* \param num_task Number of tasks to launch, can be 0, means launch * \param num_task Number of tasks to launch, can be 0, means launch
* with all available threads. * with all available threads.
* *
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda, DGL_DLL int DGLBackendParallelLaunch(
void* cdata, FDGLParallelLambda flambda, void* cdata, int num_task);
int num_task);
/*! /*!
* \brief BSP barrrier between parallel threads * \brief BSP barrrier between parallel threads
...@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda, ...@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
*/ */
DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv); DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
/*! /*!
* \brief Simple static initialization fucntion. * \brief Simple static initialization fucntion.
* Run f once and set handle to be not null. * Run f once and set handle to be not null.
...@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv); ...@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
* \param nbytes Number of bytes in the closure data. * \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens * \return 0 when no error is thrown, -1 when failure happens
*/ */
DGL_DLL int DGLBackendRunOnce(void** handle, DGL_DLL int DGLBackendRunOnce(
int (*f)(void*), void** handle, int (*f)(void*), void* cdata, int nbytes);
void *cdata,
int nbytes);
#ifdef __cplusplus #ifdef __cplusplus
} // DGL_EXTERN_C } // DGL_EXTERN_C
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment