"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "97c59be6536c1f7c4779c0a2ecd34efb878dfb13"
Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
......@@ -15,9 +15,9 @@ namespace featgraph {
void LoadFeatGraphModule(const std::string& path);
/* \brief Call Featgraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out);
void SDDMMTreeReduction(
DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* rhs, DLManagedTensor* out);
} // namespace featgraph
} // namespace dgl
......
......@@ -3,18 +3,18 @@
* \file featgraph/src/featgraph.cc
* \brief FeatGraph kernels.
*/
#include <dmlc/logging.h>
#include <featgraph.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <featgraph.h>
namespace dgl {
namespace featgraph {
/* \brief Singleton that loads the featgraph module. */
class FeatGraphModule {
public:
public:
static FeatGraphModule* Global() {
static FeatGraphModule inst;
return &inst;
......@@ -32,7 +32,8 @@ public:
}
return ret;
}
private:
private:
tvm::runtime::Module mod;
FeatGraphModule() {}
};
......@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) {
/* \brief Convert DLDataType to string. */
inline std::string DTypeAsStr(const DLDataType& t) {
switch(t.code) {
case 0U: return "int" + std::to_string(t.bits);
case 1U: return "uint" + std::to_string(t.bits);
case 2U: return "float" + std::to_string(t.bits);
case 3U: return "bfloat" + std::to_string(t.bits);
default: LOG(FATAL) << "Type code " << t.code << " not recognized";
switch (t.code) {
case 0U:
return "int" + std::to_string(t.bits);
case 1U:
return "uint" + std::to_string(t.bits);
case 2U:
return "float" + std::to_string(t.bits);
case 3U:
return "bfloat" + std::to_string(t.bits);
default:
LOG(FATAL) << "Type code " << t.code << " not recognized";
}
}
/* \brief Get operator filename. */
inline std::string GetOperatorName(
const std::string& base_name,
const DLDataType& dtype,
const std::string& base_name, const DLDataType& dtype,
const DLDataType& idtype) {
return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
}
/* \brief Call FeatGraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out) {
void SDDMMTreeReduction(
DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* rhs, DLManagedTensor* out) {
tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
std::string f_name = GetOperatorName("SDDMMTreeReduction",
(row->dl_tensor).dtype,
(lhs->dl_tensor).dtype);
std::string f_name = GetOperatorName(
"SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
if (f != nullptr)
f(row, col, lhs, rhs, out);
if (f != nullptr) f(row, col, lhs, rhs, out);
}
} // namespace featgraph
......
/*
* NOTE(zihao): this file was modified from TVM project:
* - https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
*
* -
* https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
......
......@@ -8,10 +8,10 @@
*/
#ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_
#include "./aten/types.h"
#include "./aten/array_ops.h"
#include "./aten/coo.h"
#include "./aten/csr.h"
#include "./aten/macro.h"
#include "./aten/spmat.h"
#include "./aten/csr.h"
#include "./aten/coo.h"
#include "./aten/types.h"
#endif // DGL_ARRAY_H_
......@@ -12,8 +12,8 @@
#define CUB_INLINE inline
#endif // __CUDA_ARCH__
#include <iterator>
#include <algorithm>
#include <iterator>
#include <utility>
namespace dgl {
......@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) {
r1.swap(r2);
}
// PairRef and PairIterator that serves as an iterator over a pair of arrays in a
// zipped fashion like zip(a, b).
// PairRef and PairIterator that serves as an iterator over a pair of arrays in
// a zipped fashion like zip(a, b).
template <typename DType>
struct PairRef {
PairRef() = delete;
......@@ -69,9 +69,7 @@ struct PairRef {
*b = val.second;
return *this;
}
CUB_INLINE operator Pair<DType>() const {
return Pair<DType>(*a, *b);
}
CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }
CUB_INLINE operator std::pair<DType, DType>() const {
return std::make_pair(*a, *b);
}
......@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {
}
template <typename DType>
struct PairIterator : public std::iterator<std::random_access_iterator_tag,
Pair<DType>,
std::ptrdiff_t,
Pair<DType*>,
PairRef<DType>> {
struct PairIterator : public std::iterator<
std::random_access_iterator_tag, Pair<DType>,
std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {
PairIterator() = default;
PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default;
......@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default;
CUB_INLINE bool operator==(const PairIterator& other) const { return a == other.a; }
CUB_INLINE bool operator!=(const PairIterator& other) const { return a != other.a; }
CUB_INLINE bool operator<(const PairIterator& other) const { return a < other.a; }
CUB_INLINE bool operator>(const PairIterator& other) const { return a > other.a; }
CUB_INLINE bool operator<=(const PairIterator& other) const { return a <= other.a; }
CUB_INLINE bool operator>=(const PairIterator& other) const { return a >= other.a; }
CUB_INLINE bool operator==(const PairIterator& other) const {
return a == other.a;
}
CUB_INLINE bool operator!=(const PairIterator& other) const {
return a != other.a;
}
CUB_INLINE bool operator<(const PairIterator& other) const {
return a < other.a;
}
CUB_INLINE bool operator>(const PairIterator& other) const {
return a > other.a;
}
CUB_INLINE bool operator<=(const PairIterator& other) const {
return a <= other.a;
}
CUB_INLINE bool operator>=(const PairIterator& other) const {
return a >= other.a;
}
CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {
a += movement;
b += movement;
......@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {
return a - other.a;
}
CUB_INLINE PairRef<DType> operator*() const {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator*() {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator[](size_t offset) const {
return PairRef<DType>(a + offset, b + offset);
}
......
......@@ -8,8 +8,9 @@
#include <string>
#include <vector>
#include "./types.h"
#include "../runtime/object.h"
#include "./types.h"
namespace dgl {
......@@ -56,31 +57,28 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret;
if (code & COO_CODE)
ret.push_back(SparseFormat::kCOO);
if (code & CSR_CODE)
ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC);
if (code & COO_CODE) ret.push_back(SparseFormat::kCOO);
if (code & CSR_CODE) ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE) ret.push_back(SparseFormat::kCSC);
return ret;
}
inline dgl_format_code_t
SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline dgl_format_code_t SparseFormatsToCode(
const std::vector<SparseFormat>& formats) {
dgl_format_code_t ret = 0;
for (auto format : formats) {
switch (format) {
case SparseFormat::kCOO:
ret |= COO_CODE;
break;
case SparseFormat::kCSR:
ret |= CSR_CODE;
break;
case SparseFormat::kCSC:
ret |= CSC_CODE;
break;
default:
LOG(FATAL) << "Only support COO/CSR/CSC formats.";
case SparseFormat::kCOO:
ret |= COO_CODE;
break;
case SparseFormat::kCSR:
ret |= CSR_CODE;
break;
case SparseFormat::kCSC:
ret |= CSC_CODE;
break;
default:
LOG(FATAL) << "Only support COO/CSR/CSC formats.";
}
}
return ret;
......@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = "";
if (code & COO_CODE)
ret += "coo ";
if (code & CSR_CODE)
ret += "csr ";
if (code & CSC_CODE)
ret += "csc ";
if (code & COO_CODE) ret += "coo ";
if (code & CSR_CODE) ret += "csr ";
if (code & CSC_CODE) ret += "csc ";
return ret;
}
inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & COO_CODE)
return SparseFormat::kCOO;
if (code & CSC_CODE)
return SparseFormat::kCSC;
if (code & COO_CODE) return SparseFormat::kCOO;
if (code & CSC_CODE) return SparseFormat::kCSC;
return SparseFormat::kCSR;
}
......@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object {
// Shape of this matrix.
int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}.
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,
// col, data}.
std::vector<IdArray> indices;
// Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently,
// we only consider aten::COOMatrix and aten::CSRMatrix.
// TODO(minjie): We might revisit this later to provide a more general
// solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags;
SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx,
const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {}
SparseMatrix(
int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx, const std::vector<bool>& flg)
: format(fmt),
num_rows(nrows),
num_cols(ncols),
indices(idx),
flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
......
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/aten/types.h
* \brief Array and ID types
* \brief Array and ID types
*/
#ifndef DGL_ATEN_TYPES_H_
#define DGL_ATEN_TYPES_H_
#include <cstdint>
#include "../runtime/ndarray.h"
namespace dgl {
......@@ -15,7 +16,7 @@ typedef uint64_t dgl_id_t;
typedef uint64_t dgl_type_t;
/*! \brief Type for dgl fomrat code, whose binary representation indices
* which sparse format is in use and which is not.
*
*
* Suppose the binary representation is xyz, then
* - x indicates whether csc is in use (1 for true and 0 for false).
* - y indicates whether csr is in use.
......
......@@ -7,17 +7,17 @@
#ifndef DGL_BASE_HETEROGRAPH_H_
#define DGL_BASE_HETEROGRAPH_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h"
#include "array.h"
#include "aten/spmat.h"
#include "aten/types.h"
#include "graph_interface.h"
#include "array.h"
namespace dgl {
......@@ -52,31 +52,26 @@ enum class EdgeDir {
*/
class BaseHeteroGraph : public runtime::Object {
public:
explicit BaseHeteroGraph(GraphPtr meta_graph): meta_graph_(meta_graph) {}
explicit BaseHeteroGraph(GraphPtr meta_graph) : meta_graph_(meta_graph) {}
virtual ~BaseHeteroGraph() = default;
////////////////////////// query/operations on meta graph ////////////////////////
////////////////////// query/operations on meta graph ///////////////////////
/*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}
virtual uint64_t NumVertexTypes() const { return meta_graph_->NumVertices(); }
/*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}
virtual uint64_t NumEdgeTypes() const { return meta_graph_->NumEdges(); }
/*! \return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(dgl_type_t etype) const {
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(
dgl_type_t etype) const {
return meta_graph_->FindEdge(etype);
}
/*! \return the meta graph */
virtual GraphPtr meta_graph() const {
return meta_graph_;
}
virtual GraphPtr meta_graph() const { return meta_graph_; }
/*!
* \brief Return the bipartite graph of the given edge type.
......@@ -85,7 +80,7 @@ class BaseHeteroGraph : public runtime::Object {
*/
virtual HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const = 0;
////////////////////////// query/operations on realized graph ////////////////////////
///////////////////// query/operations on realized graph /////////////////////
/*! \brief Add vertices to the given vertex type */
virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0;
......@@ -128,7 +123,8 @@ class BaseHeteroGraph : public runtime::Object {
virtual void RecordStream(DGLStreamHandle stream) = 0;
/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
* \brief Get the number of integer bits used to store node/edge ids (32 or
* 64).
*/
// TODO(BarclayII) replace NumBits() calls to DataType() calls
virtual uint8_t NumBits() const = 0;
......@@ -156,14 +152,18 @@ class BaseHeteroGraph : public runtime::Object {
/*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
virtual bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;
virtual BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;
/*!
* \brief Find the predecessors of a vertex.
......@@ -187,14 +187,13 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Get all edge ids between the two given endpoints
* \note The given src and dst vertices should belong to the source vertex type
* and the dest vertex type of the given edge type, respectively.
* \param etype The edge type
* \param src The source vertex.
* \param dst The destination vertex.
* \return the edge id array.
* \note The given src and dst vertices should belong to the source vertex
* type and the dest vertex type of the given edge type, respectively. \param
* etype The edge type \param src The source vertex. \param dst The
* destination vertex. \return the edge id array.
*/
virtual IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
virtual IdArray EdgeId(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;
/*!
* \brief Get all edge ids between the given endpoint pairs.
......@@ -204,34 +203,40 @@ class BaseHeteroGraph : public runtime::Object {
* \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs.
*/
virtual EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const = 0;
virtual EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*!
* \brief Get edge ids between the given endpoint pairs.
*
* Only find one matched edge Ids even if there are multiple matches due to parallel
* edges. The i^th Id in the returned array is for edge (src[i], dst[i]).
* Only find one matched edge Ids even if there are multiple matches due to
* parallel edges. The i^th Id in the returned array is for edge (src[i],
* dst[i]).
*
* \param etype The edge type
* \param src The src vertex ids.
* \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs.
*/
virtual IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const = 0;
virtual IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param etype The edge type
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const = 0;
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const = 0;
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param etype The edge type
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0;
......@@ -277,14 +282,14 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and
* dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order.
* \param etype The edge type
* \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges.
* \note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order. Otherwise,
* in the arbitrary order. \param etype The edge type \param order The order
* of the returned edge list. \return the id arrays of the two endpoints of
* the edges.
*/
virtual EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const = 0;
virtual EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const = 0;
/*!
* \brief Get the in degree of the given vertex.
......@@ -373,15 +378,15 @@ class BaseHeteroGraph : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids
*
* If the fmt is 'coo', the function should return one array of shape (2, nnz),
* representing a horitonzal stack of row and col indices.
* If the fmt is 'coo', the function should return one array of shape (2,
* nnz), representing a horitonzal stack of row and col indices.
*
* \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays.
*/
virtual std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const = 0;
dgl_type_t etype, bool transpose, const std::string& fmt) const = 0;
/*!
* \brief Determine which format to use with a preference.
......@@ -451,39 +456,42 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Extract the induced subgraph by the given vertices.
*
* The length of the given vector should be equal to the number of vertex types.
* Empty arrays can be provided if no vertex is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
* node/edge.
* The length of the given vector should be equal to the number of vertex
* types. Empty arrays can be provided if no vertex is needed for the type.
* The result subgraph has the same meta graph with the parent, but some types
* can have no node/edge.
*
* \param vids the induced vertices per type.
* \return the subgraph.
*/
virtual HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const = 0;
virtual HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const = 0;
/*!
* \brief Extract the induced subgraph by the given edges.
*
* The length of the given vector should be equal to the number of edge types.
* Empty arrays can be provided if no edge is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
* node/edge.
* subgraph has the same meta graph with the parent, but some types can have
* no node/edge.
*
* \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices
* may have no incident edges.
* \return the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some
* vertices may have no incident edges. \return the subgraph.
*/
virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;
/*!
* \brief Convert the list of requested unitgraph graphs into a single unitgraph graph.
* \brief Convert the list of requested unitgraph graphs into a single
* unitgraph graph.
*
* \param etypes The list of edge type IDs.
* \return The flattened graph, with induced source/edge/destination types/IDs.
* \return The flattened graph, with induced source/edge/destination
* types/IDs.
*/
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
virtual FlattenedHeteroGraphPtr Flatten(
const std::vector<dgl_type_t>& etypes) const {
LOG(FATAL) << "Flatten operation unsupported";
return nullptr;
}
......@@ -502,7 +510,7 @@ class BaseHeteroGraph : public runtime::Object {
GraphPtr meta_graph_;
// empty constructor
BaseHeteroGraph(){}
BaseHeteroGraph() {}
};
// Define HeteroGraphRef
......@@ -527,9 +535,9 @@ struct HeteroSubgraph : public runtime::Object {
HeteroGraphPtr graph;
/*!
* \brief The induced vertex ids of each entity type.
* The vector length is equal to the number of vertex types in the parent graph.
* Each array i has the same length as the number of vertices in type i.
* Empty array is allowed if the mapping is identity.
* The vector length is equal to the number of vertex types in the parent
* graph. Each array i has the same length as the number of vertices in type
* i. Empty array is allowed if the mapping is identity.
*/
std::vector<IdArray> induced_vertices;
/*!
......@@ -553,7 +561,8 @@ struct FlattenedHeteroGraph : public runtime::Object {
HeteroGraphRef graph;
/*!
* \brief Mapping from source node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously.
* \note The induced type array guarantees that the same type always appear
* contiguously.
*/
IdArray induced_srctype;
/*!
......@@ -564,7 +573,8 @@ struct FlattenedHeteroGraph : public runtime::Object {
IdArray induced_srcid;
/*!
* \brief Mapping from edge ID to edge type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously.
* \note The induced type array guarantees that the same type always appear
* contiguously.
*/
IdArray induced_etype;
/*!
......@@ -575,17 +585,20 @@ struct FlattenedHeteroGraph : public runtime::Object {
IdArray induced_eid;
/*!
* \brief Mapping from destination node ID to node type in parent graph
* \note The induced type array guarantees that the same type always appear contiguously.
* \note The induced type array guarantees that the same type always appear
* contiguously.
*/
IdArray induced_dsttype;
/*!
* \brief The set of node types in parent graph appearing in destination nodes.
* \brief The set of node types in parent graph appearing in destination
* nodes.
*/
IdArray induced_dsttype_set;
/*! \brief Mapping from destination node ID to local node ID in parent graph */
/*! \brief Mapping from destination node ID to local node ID in parent graph
*/
IdArray induced_dstid;
void VisitAttrs(runtime::AttrVisitor *v) final {
void VisitAttrs(runtime::AttrVisitor* v) final {
v->Visit("graph", &graph);
v->Visit("induced_srctype", &induced_srctype);
v->Visit("induced_srctype_set", &induced_srctype_set);
......@@ -610,9 +623,8 @@ DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
* additionally specifying number of nodes per type.
*/
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr> &rel_graphs,
const std::vector<int64_t> &num_nodes_per_type = {});
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {});
/*!
* \brief Create a heterograph from COO input.
......@@ -629,8 +641,8 @@ HeteroGraphPtr CreateHeteroGraph(
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, bool row_sorted = false, bool col_sorted = false,
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray col, bool row_sorted = false, bool col_sorted = false,
dgl_format_code_t formats = ALL_CODE);
/*!
......@@ -656,9 +668,8 @@ HeteroGraphPtr CreateFromCOO(
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
/*!
* \brief Create a heterograph from CSR input.
......@@ -683,9 +694,8 @@ HeteroGraphPtr CreateFromCSR(
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
/*!
* \brief Create a heterograph from CSC input.
......@@ -702,23 +712,25 @@ HeteroGraphPtr CreateFromCSC(
* \brief Extract the subgraph of the in edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones
* \return Subgraph containing only the in edges. The returned graph has the same
* schema as the original one.
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest
* ones \return Subgraph containing only the in edges. The returned graph has
* the same schema as the original one.
*/
HeteroSubgraph InEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false);
/*!
* \brief Extract the subgraph of the out edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest ones
* \return Subgraph containing only the out edges. The returned graph has the same
* schema as the original one.
* \param relabel_nodes Whether to remove isolated nodes and relabel the rest
* ones \return Subgraph containing only the out edges. The returned graph has
* the same schema as the original one.
*/
HeteroSubgraph OutEdgeGraph(
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes, bool relabel_nodes = false);
const HeteroGraphPtr graph, const std::vector<IdArray>& nodes,
bool relabel_nodes = false);
/*!
* \brief Joint union multiple graphs into one graph.
......@@ -735,7 +747,8 @@ HeteroGraphPtr JointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*!
* \brief Union multiple graphs into one with each input graph as one disjoint component.
* \brief Union multiple graphs into one with each input graph as one disjoint
* component.
*
* All input graphs should have the same metagraph.
*
......@@ -754,7 +767,8 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*!
* \brief Slice a contiguous subgraph, e.g. retrieve a component graph from a batched graph.
* \brief Slice a contiguous subgraph, e.g. retrieve a component graph from a
* batched graph.
*
* TODO(mufei): remove the meta_graph argument
*
......@@ -767,25 +781,22 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
* \return Sliced graph
*/
HeteroGraphPtr SliceHeteroGraph(
GraphPtr meta_graph,
HeteroGraphPtr batched_graph,
IdArray num_nodes_per_type,
IdArray start_nid_per_type,
IdArray num_edges_per_type,
IdArray start_eid_per_type);
GraphPtr meta_graph, HeteroGraphPtr batched_graph,
IdArray num_nodes_per_type, IdArray start_nid_per_type,
IdArray num_edges_per_type, IdArray start_eid_per_type);
/*!
* \brief Split a graph into multiple disjoin components.
*
* Edges across different components are ignored. All the result graphs have the same
* metagraph as the input one.
* Edges across different components are ignored. All the result graphs have the
* same metagraph as the input one.
*
* The `vertex_sizes` and `edge_sizes` arrays the concatenation of arrays of each
* node/edge type. Suppose there are N vertex types, then the array length should
* be B*N, where B is the number of components to split.
* The `vertex_sizes` and `edge_sizes` arrays the concatenation of arrays of
* each node/edge type. Suppose there are N vertex types, then the array length
* should be B*N, where B is the number of components to split.
*
* TODO(minjie): remove the meta_graph argument; use vector<IdArray> for vertex_sizes
* and edge_sizes.
* TODO(minjie): remove the meta_graph argument; use vector<IdArray> for
* vertex_sizes and edge_sizes.
*
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param meta_graph Metagraph.
......@@ -796,16 +807,11 @@ HeteroGraphPtr SliceHeteroGraph(
*/
template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph,
HeteroGraphPtr batched_graph,
IdArray vertex_sizes,
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph,
HeteroGraphPtr batched_graph,
IdArray vertex_sizes,
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes,
IdArray edge_sizes);
/*!
......@@ -862,7 +868,8 @@ DGL_DEFINE_OBJECT_REF(HeteroPickleStatesRef, HeteroPickleStates);
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states);
/*!
* \brief Get the pickling state of the relation graph structure in backend tensors.
* \brief Get the pickling state of the relation graph structure in backend
* tensors.
*
* \return a HeteroPickleStates object
*/
......@@ -886,8 +893,8 @@ HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
/*!
* \brief Get the pickling states of the relation graph structure in backend tensors for
* ForkingPickler.
* \brief Get the pickling states of the relation graph structure in backend
* tensors for ForkingPickler.
*
* This is different from HeteroPickle where
* (1) Backward compatibility is not required,
......@@ -895,14 +902,11 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
*/
HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph);
#define FORMAT_HAS_CSC(format) \
((format) & CSC_CODE)
#define FORMAT_HAS_CSC(format) ((format)&CSC_CODE)
#define FORMAT_HAS_CSR(format) \
((format) & CSR_CODE)
#define FORMAT_HAS_CSR(format) ((format)&CSR_CODE)
#define FORMAT_HAS_COO(format) \
((format) & COO_CODE)
#define FORMAT_HAS_COO(format) ((format)&COO_CODE)
} // namespace dgl
......
......@@ -7,6 +7,7 @@
#include <string>
#include <vector>
#include "./runtime/ndarray.h"
using namespace dgl::runtime;
......@@ -37,7 +38,7 @@ struct BcastOff {
/*! \brief Whether broadcast is required or not. */
bool use_bcast;
/*!
* \brief Auxiliary information for kernel computation
* \brief Auxiliary information for kernel computation
* \note lhs_len refers to the left hand side operand length.
* e.g. 15 for shape (1, 3, 5)
* rhs_len refers to the right hand side operand length.
......@@ -61,7 +62,6 @@ struct BcastOff {
*/
BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl
} // namespace dgl
#endif // DGL_BCAST_H_
......@@ -6,13 +6,12 @@
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <string>
#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "graph_interface.h"
......@@ -23,7 +22,7 @@ class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! \brief Mutable graph based on adjacency list. */
class Graph: public GraphInterface {
class Graph : public GraphInterface {
public:
/*! \brief default constructor */
Graph() {}
......@@ -89,13 +88,9 @@ class Graph: public GraphInterface {
num_edges_ = 0;
}
DGLContext Context() const override {
return DGLContext{kDGLCPU, 0};
}
DGLContext Context() const override { return DGLContext{kDGLCPU, 0}; }
uint8_t NumBits() const override {
return 64;
}
uint8_t NumBits() const override { return 64; }
/*!
* \note not const since we have caches
......@@ -106,21 +101,17 @@ class Graph: public GraphInterface {
/*!
* \return whether the graph is read-only
*/
bool IsReadonly() const override {
return false;
}
bool IsReadonly() const override { return false; }
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override {
return adjlist_.size();
}
uint64_t NumVertices() const override { return adjlist_.size(); }
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override {
return num_edges_;
}
uint64_t NumEdges() const override { return num_edges_; }
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/
......@@ -132,7 +123,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;
......@@ -140,7 +132,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
......@@ -169,7 +162,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);
......@@ -178,7 +172,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
EdgeArray FindEdges(IdArray eids) const override;
......@@ -216,10 +211,11 @@ class Graph: public GraphInterface {
* \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray Edges(const std::string &order = "") const override;
EdgeArray Edges(const std::string& order = "") const override;
/*!
* \brief Get the in degree of the given vertex.
......@@ -258,13 +254,14 @@ class Graph: public GraphInterface {
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
* The induced subgraph is a subgraph formed by specifying a set of vertices
* V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the vertices preserve the order of the given id
* array, while the local index of the edges preserve the index order in the
* original graph. Vertices not in the original graph are ignored.
*
* The result subgraph is read-only.
*
......@@ -276,20 +273,22 @@ class Graph: public GraphInterface {
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
* The induced edges subgraph is a subgraph formed by specifying a set of
* edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the edges preserve the order of the given id
* array, while the local index of the vertices preserve the index order in
* the original graph. Edges not in the original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*!
* \brief Return the successor vector
......@@ -344,12 +343,11 @@ class Graph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray.
*/
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
std::vector<IdArray> GetAdj(
bool transpose, const std::string& fmt) const override;
/*! \brief Create an empty graph */
static MutableGraphPtr Create() {
return std::make_shared<Graph>();
}
static MutableGraphPtr Create() { return std::make_shared<Graph>(); }
/*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO(
......
......@@ -6,11 +6,11 @@
#ifndef DGL_GRAPH_INTERFACE_H_
#define DGL_GRAPH_INTERFACE_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h"
#include "array.h"
......@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
* \brief This class references data in std::vector.
*
* This isn't a STL-style iterator. It provides a STL data container interface.
* but it doesn't own data itself. instead, it only references data in std::vector.
* but it doesn't own data itself. instead, it only references data in
* std::vector.
*/
class DGLIdIters {
public:
......@@ -34,18 +35,11 @@ class DGLIdIters {
this->begin_ = begin;
this->end_ = end;
}
const dgl_id_t *begin() const {
return this->begin_;
}
const dgl_id_t *end() const {
return this->end_;
}
dgl_id_t operator[](int64_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
const dgl_id_t *begin() const { return this->begin_; }
const dgl_id_t *end() const { return this->end_; }
dgl_id_t operator[](int64_t i) const { return *(this->begin_ + i); }
size_t size() const { return this->end_ - this->begin_; }
private:
const dgl_id_t *begin_{nullptr}, *end_{nullptr};
};
......@@ -63,23 +57,15 @@ class DGLIdIters32 {
this->begin_ = begin;
this->end_ = end;
}
const int32_t *begin() const {
return this->begin_;
}
const int32_t *end() const {
return this->end_;
}
int32_t operator[](int32_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
const int32_t *begin() const { return this->begin_; }
const int32_t *end() const { return this->end_; }
int32_t operator[](int32_t i) const { return *(this->begin_ + i); }
size_t size() const { return this->end_ - this->begin_; }
private:
const int32_t *begin_{nullptr}, *end_{nullptr};
};
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
......@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object {
virtual DGLContext Context() const = 0;
/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
* \brief Get the number of integer bits used to store node/edge ids
* (32 or 64).
*/
virtual uint8_t NumBits() const = 0;
......@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object {
virtual uint64_t NumEdges() const = 0;
/*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}
virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/
......@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array.
*/
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;
......@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array.
*/
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;
......@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
virtual EdgeArray FindEdges(IdArray eids) const = 0;
......@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and
* dst ids. If order is "eid", they are in their edge id order.
* \note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order.
* \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges.
......@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
* The induced subgraph is a subgraph formed by specifying a set of vertices
* V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the vertices preserve the order of the given id
* array, while the local index of the edges preserve the index order in the
* original graph. Vertices not in the original graph are ignored.
*
* The result subgraph is read-only.
*
......@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
* The induced edges subgraph is a subgraph formed by specifying a set of
* edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the edges preserve the order of the given id
* array, while the local index of the vertices preserve the index order in
* the original graph. Edges not in the original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices
* may have no incident edges.
* \param preserve_nodes If true, the vertices will not be relabeled, so some
* vertices may have no incident edges.
* \return the induced edge subgraph
*/
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;
virtual Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const = 0;
/*!
* \brief Return the successor vector
......@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids
*
* If the fmt is 'coo', the function should return one array of shape (2, nnz),
* representing a horitonzal stack of row and col indices.
* If the fmt is 'coo', the function should return one array of shape (2,
* nnz), representing a horitonzal stack of row and col indices.
*
* \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays.
*/
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;
virtual std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const = 0;
/*!
* \brief Sort the columns in CSR.
......@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object {
* This sorts the columns in each row based on the column Ids.
* The edge ids should be sorted accordingly.
*/
virtual void SortCSR() {
}
virtual void SortCSR() {}
static constexpr const char* _type_key = "graph.Graph";
static constexpr const char *_type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
};
......@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object {
GraphPtr graph;
/*!
* \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph.
* \note This is also a map from the new vertex id to the vertex id in the
* parent graph.
*/
IdArray induced_vertices;
/*!
* \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph.
* \note This is also a map from the new edge id to the edge id in the parent
* graph.
*/
IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph";
static constexpr const char *_type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
};
/*! \brief Subgraph data structure for negative subgraph */
struct NegSubgraph: public Subgraph {
struct NegSubgraph : public Subgraph {
/*! \brief The existence of the negative edges in the parent graph. */
IdArray exist;
......@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph {
};
/*! \brief Subgraph data structure for halo subgraph */
struct HaloSubgraph: public Subgraph {
struct HaloSubgraph : public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes;
};
......
......@@ -7,6 +7,7 @@
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
#include "immutable_graph.h"
......@@ -17,7 +18,8 @@ class GraphOp {
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
* The returned graph preserves the vertex and edge index in the original
* graph.
*
* \return the reversed graph
*/
......@@ -40,10 +42,10 @@ class GraphOp {
* \brief Return a disjoint union of the input graphs.
*
* The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes
* in the given sequence order. For example, giving input [g1, g2, g3], where
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
* in the result graph. Edge ids are re-assigned similarly.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph
* sizes in the given sequence order. For example, giving input [g1, g2, g3],
* where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become
* node#7 in the result graph. Edge ids are re-assigned similarly.
*
* The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type.
......@@ -62,12 +64,13 @@ class GraphOp {
*
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
*
* \param graph The graph to be partitioned.
* \param num The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionByNum(GraphPtr graph, int64_t num);
static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num);
/*!
* \brief Partition the graph into several subgraphs.
......@@ -78,21 +81,22 @@ class GraphOp {
*
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
*
* \param graph The graph to be partitioned.
* \param sizes The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionBySizes(GraphPtr graph, IdArray sizes);
static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes);
/*!
* \brief Map vids in the parent graph to the vids in the subgraph.
*
* If the Id doesn't exist in the subgraph, -1 will be used.
*
* \param parent_vid_map An array that maps the vids in the parent graph to the
* subgraph. The elements store the vertex Ids in the parent graph, and the
* indices indicate the vertex Ids in the subgraph.
* \param parent_vid_map An array that maps the vids in the parent graph to
* the subgraph. The elements store the vertex Ids in the parent graph, and
* the indices indicate the vertex Ids in the subgraph.
* \param query The vertex Ids in the parent graph.
* \return an Id array that contains the subgraph node Ids.
*/
......@@ -134,15 +138,18 @@ class GraphOp {
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable.
* \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable.
* \param graph The input graph.
* \return a new immutable bidirected graph.
* \return a new immutable bidirected
* graph.
*/
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable
* and call gk_csr_MakeSymmetric in GKlib. This is more efficient than ToBidirectedImmutableGraph.
* It return a null pointer if the conversion fails.
* \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* than ToBidirectedImmutableGraph. It return a null pointer if the conversion
* fails.
*
* \param graph The input graph.
* \return a new immutable bidirected graph.
......@@ -151,21 +158,25 @@ class GraphOp {
/*!
* \brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within `num_hops`.
* The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`.
* \param graph The input graph.
* \param nodes The input nodes that form the core of the induced subgraph.
* \param num_hops The number of hops to reach.
* \return the induced subgraph with HALO nodes.
*/
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops);
static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops);
/*!
* \brief Reorder the nodes in the immutable graph.
* \param graph The input graph.
* \param new_order The node Ids in the new graph. The index in `new_order` is old node Ids.
* \param new_order The node Ids in the new graph. The index in `new_order` is
* old node Ids.
* \return the graph with reordered node Ids
*/
static GraphPtr ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order);
static GraphPtr ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order);
};
} // namespace dgl
......
......@@ -16,7 +16,8 @@ namespace dgl {
* \brief Class for representing frontiers.
*
* Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value).
* An optional tag can be specified on each node/edge (represented by an int
* value).
*/
struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */
......@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge
* is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge
* is NOT in the DFS tree.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the
* edge is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the
* edge is NOT in the DFS tree.
*
* \param csr The input csr matrix.
* \param sources Source nodes.
......@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result
*/
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels);
} // namespace aten
} // namespace dgl
......
......@@ -6,17 +6,18 @@
#ifndef DGL_IMMUTABLE_GRAPH_H_
#define DGL_IMMUTABLE_GRAPH_H_
#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
#include <algorithm>
#include <cstdint>
#include <memory>
#include "runtime/ndarray.h"
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "base_heterograph.h"
#include "graph_interface.h"
#include "lazy.h"
#include "base_heterograph.h"
#include "runtime/ndarray.h"
namespace dgl {
......@@ -37,16 +38,16 @@ class CSR : public GraphInterface {
CSR(int64_t num_vertices, int64_t num_edges);
// Create a csr graph whose memory is stored in the shared memory
// that has the given number of verts and edges.
CSR(const std::string &shared_mem_name,
int64_t num_vertices, int64_t num_edges);
CSR(const std::string &shared_mem_name, int64_t num_vertices,
int64_t num_edges);
// Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
// Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
CSR(int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,
IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
// Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indcies.
......@@ -65,31 +66,19 @@ class CSR : public GraphInterface {
LOG(FATAL) << "CSR graph does not allow mutation.";
}
void Clear() override {
LOG(FATAL) << "CSR graph does not allow mutation.";
}
void Clear() override { LOG(FATAL) << "CSR graph does not allow mutation."; }
DGLContext Context() const override {
return adj_.indptr->ctx;
}
DGLContext Context() const override { return adj_.indptr->ctx; }
uint8_t NumBits() const override {
return adj_.indices->dtype.bits;
}
uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
bool IsMultigraph() const override;
bool IsReadonly() const override {
return true;
}
bool IsReadonly() const override { return true; }
uint64_t NumVertices() const override {
return adj_.indptr->shape[0] - 1;
}
uint64_t NumVertices() const override { return adj_.indptr->shape[0] - 1; }
uint64_t NumEdges() const override {
return adj_.indices->shape[0];
}
uint64_t NumEdges() const override { return adj_.indices->shape[0]; }
BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for CSR graph";
......@@ -102,7 +91,7 @@ class CSR : public GraphInterface {
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "CSR graph does not support efficient predecessor query."
<< " Please use successors on the reverse CSR graph.";
<< " Please use successors on the reverse CSR graph.";
return {};
}
......@@ -114,25 +103,25 @@ class CSR : public GraphInterface {
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
LOG(FATAL) << "CSR graph does not support efficient FindEdge."
<< " Please use COO graph.";
<< " Please use COO graph.";
return {};
}
EdgeArray FindEdges(IdArray eids) const override {
LOG(FATAL) << "CSR graph does not support efficient FindEdges."
<< " Please use COO graph.";
<< " Please use COO graph.";
return {};
}
EdgeArray InEdges(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient inedges query."
<< " Please use outedges on the reverse CSR graph.";
<< " Please use outedges on the reverse CSR graph.";
return {};
}
EdgeArray InEdges(IdArray vids) const override {
LOG(FATAL) << "CSR graph does not support efficient inedges query."
<< " Please use outedges on the reverse CSR graph.";
<< " Please use outedges on the reverse CSR graph.";
return {};
}
......@@ -144,13 +133,13 @@ class CSR : public GraphInterface {
uint64_t InDegree(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient indegree query."
<< " Please use outdegree on the reverse CSR graph.";
<< " Please use outdegree on the reverse CSR graph.";
return 0;
}
DegreeArray InDegrees(IdArray vids) const override {
LOG(FATAL) << "CSR graph does not support efficient indegree query."
<< " Please use outdegree on the reverse CSR graph.";
<< " Please use outdegree on the reverse CSR graph.";
return {};
}
......@@ -162,9 +151,10 @@ class CSR : public GraphInterface {
Subgraph VertexSubgraph(IdArray vids) const override;
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override {
Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override {
LOG(FATAL) << "CSR graph does not support efficient EdgeSubgraph."
<< " Please use COO graph instead.";
<< " Please use COO graph instead.";
return {};
}
......@@ -174,25 +164,24 @@ class CSR : public GraphInterface {
DGLIdIters PredVec(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient PredVec."
<< " Please use SuccVec on the reverse CSR graph.";
<< " Please use SuccVec on the reverse CSR graph.";
return DGLIdIters(nullptr, nullptr);
}
DGLIdIters InEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient InEdgeVec."
<< " Please use OutEdgeVec on the reverse CSR graph.";
<< " Please use OutEdgeVec on the reverse CSR graph.";
return DGLIdIters(nullptr, nullptr);
}
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {adj_.indptr, adj_.indices, adj_.data};
}
/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return !shared_mem_name_.empty();
}
bool IsSharedMem() const { return !shared_mem_name_.empty(); }
/*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const;
......@@ -205,16 +194,14 @@ class CSR : public GraphInterface {
* \note The csr matrix shares the storage with this graph.
* The data field of the CSR matrix stores the edge ids.
*/
aten::CSRMatrix ToCSRMatrix() const {
return adj_;
}
aten::CSRMatrix ToCSRMatrix() const { return adj_; }
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The graph under another context.
*/
CSR CopyTo(const DGLContext& ctx) const;
CSR CopyTo(const DGLContext &ctx) const;
/*!
* \brief Copy data to shared memory.
......@@ -242,11 +229,10 @@ class CSR : public GraphInterface {
bool Load(dmlc::Stream *fs);
/*! \return Save CSR to stream */
void Save(dmlc::Stream* fs) const;
void Save(dmlc::Stream *fs) const;
void SortCSR() override {
if (adj_.sorted)
return;
if (adj_.sorted) return;
aten::CSRSort_(&adj_);
}
......@@ -254,7 +240,7 @@ class CSR : public GraphInterface {
friend class Serializer;
/*! \brief private default constructor */
CSR() {adj_.sorted = false;}
CSR() { adj_.sorted = false; }
// The internal CSR adjacency matrix.
// The data field stores edge ids.
aten::CSRMatrix adj_;
......@@ -267,8 +253,8 @@ class CSR : public GraphInterface {
class COO : public GraphInterface {
public:
// Create a coo graph that shares the given src and dst
COO(int64_t num_vertices, IdArray src, IdArray dst,
bool row_sorted = false, bool col_sorted = false);
COO(int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted = false,
bool col_sorted = false);
// TODO(da): add constructor for creating COO from shared memory
......@@ -284,35 +270,21 @@ class COO : public GraphInterface {
LOG(FATAL) << "COO graph does not allow mutation.";
}
void Clear() override {
LOG(FATAL) << "COO graph does not allow mutation.";
}
void Clear() override { LOG(FATAL) << "COO graph does not allow mutation."; }
DGLContext Context() const override {
return adj_.row->ctx;
}
DGLContext Context() const override { return adj_.row->ctx; }
uint8_t NumBits() const override {
return adj_.row->dtype.bits;
}
uint8_t NumBits() const override { return adj_.row->dtype.bits; }
bool IsMultigraph() const override;
bool IsReadonly() const override {
return true;
}
bool IsReadonly() const override { return true; }
uint64_t NumVertices() const override {
return adj_.num_rows;
}
uint64_t NumVertices() const override { return adj_.num_rows; }
uint64_t NumEdges() const override {
return adj_.row->shape[0];
}
uint64_t NumEdges() const override { return adj_.row->shape[0]; }
bool HasVertex(dgl_id_t vid) const override {
return vid < NumVertices();
}
bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for COO graph";
......@@ -321,37 +293,37 @@ class COO : public GraphInterface {
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return false;
}
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "COO graph does not support efficient Predecessors."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "COO graph does not support efficient Successors."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const override {
LOG(FATAL) << "COO graph does not support efficient EdgeId."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
EdgeArray EdgeIds(IdArray src, IdArray dst) const override {
LOG(FATAL) << "COO graph does not support efficient EdgeId."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
......@@ -361,25 +333,25 @@ class COO : public GraphInterface {
EdgeArray InEdges(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InEdges."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
EdgeArray InEdges(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient InEdges."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
EdgeArray OutEdges(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdges."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
EdgeArray OutEdges(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdges."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
......@@ -387,61 +359,63 @@ class COO : public GraphInterface {
uint64_t InDegree(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InDegree."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return 0;
}
DegreeArray InDegrees(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient InDegrees."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
uint64_t OutDegree(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutDegree."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return 0;
}
DegreeArray OutDegrees(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient OutDegrees."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
Subgraph VertexSubgraph(IdArray vids) const override {
LOG(FATAL) << "COO graph does not support efficient VertexSubgraph."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
DGLIdIters SuccVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient SuccVec."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr);
}
DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient OutEdgeVec."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr);
}
DGLIdIters PredVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient PredVec."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr);
}
DGLIdIters InEdgeVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient InEdgeVec."
<< " Please use CSR graph or AdjList graph instead.";
<< " Please use CSR graph or AdjList graph instead.";
return DGLIdIters(nullptr, nullptr);
}
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) {
return {aten::HStack(adj_.col, adj_.row)};
......@@ -463,16 +437,14 @@ class COO : public GraphInterface {
* \note The coo matrix shares the storage with this graph.
* The data field of the coo matrix is none.
*/
aten::COOMatrix ToCOOMatrix() const {
return adj_;
}
aten::COOMatrix ToCOOMatrix() const { return adj_; }
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The graph under another context.
*/
COO CopyTo(const DGLContext& ctx) const;
COO CopyTo(const DGLContext &ctx) const;
/*!
* \brief Copy data to shared memory.
......@@ -489,9 +461,7 @@ class COO : public GraphInterface {
COO AsNumBits(uint8_t bits) const;
/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return false;
}
bool IsSharedMem() const { return false; }
// member getters
......@@ -513,40 +483,40 @@ class COO : public GraphInterface {
*
* DGL's graph is directed. Vertices are integers enumerated from zero.
*/
class ImmutableGraph: public GraphInterface {
class ImmutableGraph : public GraphInterface {
public:
/*! \brief Construct an immutable graph from the COO format. */
explicit ImmutableGraph(COOPtr coo): coo_(coo) { }
explicit ImmutableGraph(COOPtr coo) : coo_(coo) {}
/*!
* \brief Construct an immutable graph from the CSR format.
*
* For a single graph, we need two CSRs, one stores the in-edges of vertices and
* the other stores the out-edges of vertices. These two CSRs stores the same edges.
* The reason we need both is that some operators are faster on in-edge CSR and
* the other operators are faster on out-edge CSR.
* For a single graph, we need two CSRs, one stores the in-edges of vertices
* and the other stores the out-edges of vertices. These two CSRs stores the
* same edges. The reason we need both is that some operators are faster on
* in-edge CSR and the other operators are faster on out-edge CSR.
*
* However, not both CSRs are required. Technically, one CSR contains all information.
* Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
* construct one of the CSRs that runs fast for some operations we expect and construct
* the other CSR on demand.
* However, not both CSRs are required. Technically, one CSR contains all
* information. Thus, when we construct a temporary graphs (e.g., the sampled
* subgraphs), we only construct one of the CSRs that runs fast for some
* operations we expect and construct the other CSR on demand.
*/
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
: in_csr_(in_csr), out_csr_(out_csr) {
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
}
/*! \brief Construct an immutable graph from one CSR. */
explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) { }
explicit ImmutableGraph(CSRPtr csr) : out_csr_(csr) {}
/*! \brief default copy constructor */
ImmutableGraph(const ImmutableGraph& other) = default;
ImmutableGraph(const ImmutableGraph &other) = default;
#ifndef _MSC_VER
/*! \brief default move constructor */
ImmutableGraph(ImmutableGraph&& other) = default;
ImmutableGraph(ImmutableGraph &&other) = default;
#else
ImmutableGraph(ImmutableGraph&& other) {
ImmutableGraph(ImmutableGraph &&other) {
this->in_csr_ = other.in_csr_;
this->out_csr_ = other.out_csr_;
this->coo_ = other.coo_;
......@@ -557,7 +527,7 @@ class ImmutableGraph: public GraphInterface {
#endif // _MSC_VER
/*! \brief default assign constructor */
ImmutableGraph& operator=(const ImmutableGraph& other) = default;
ImmutableGraph &operator=(const ImmutableGraph &other) = default;
/*! \brief default destructor */
~ImmutableGraph() = default;
......@@ -578,28 +548,20 @@ class ImmutableGraph: public GraphInterface {
LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
}
DGLContext Context() const override {
return AnyGraph()->Context();
}
DGLContext Context() const override { return AnyGraph()->Context(); }
uint8_t NumBits() const override {
return AnyGraph()->NumBits();
}
uint8_t NumBits() const override { return AnyGraph()->NumBits(); }
/*!
* \note not const since we have caches
* \return whether the graph is a multigraph
*/
bool IsMultigraph() const override {
return AnyGraph()->IsMultigraph();
}
bool IsMultigraph() const override { return AnyGraph()->IsMultigraph(); }
/*!
* \return whether the graph is read-only
*/
bool IsReadonly() const override {
return true;
}
bool IsReadonly() const override { return true; }
/**
* \brief Check if the graph is unibipartite.
......@@ -616,19 +578,13 @@ class ImmutableGraph: public GraphInterface {
}
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override {
return AnyGraph()->NumVertices();
}
uint64_t NumVertices() const override { return AnyGraph()->NumVertices(); }
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override {
return AnyGraph()->NumEdges();
}
uint64_t NumEdges() const override { return AnyGraph()->NumEdges(); }
/*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const override {
return vid < NumVertices();
}
bool HasVertex(dgl_id_t vid) const override { return vid < NumVertices(); }
BoolArray HasVertices(IdArray vids) const override;
......@@ -652,7 +608,8 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
......@@ -662,7 +619,8 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override {
......@@ -706,7 +664,8 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return GetCOO()->FindEdge(eid);
......@@ -715,7 +674,8 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
EdgeArray FindEdges(IdArray eids) const override {
return GetCOO()->FindEdges(eids);
......@@ -728,7 +688,7 @@ class ImmutableGraph: public GraphInterface {
* \return the edges
*/
EdgeArray InEdges(dgl_id_t vid) const override {
const EdgeArray& ret = GetInCSR()->OutEdges(vid);
const EdgeArray &ret = GetInCSR()->OutEdges(vid);
return {ret.dst, ret.src, ret.id};
}
......@@ -738,7 +698,7 @@ class ImmutableGraph: public GraphInterface {
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray InEdges(IdArray vids) const override {
const EdgeArray& ret = GetInCSR()->OutEdges(vids);
const EdgeArray &ret = GetInCSR()->OutEdges(vids);
return {ret.dst, ret.src, ret.id};
}
......@@ -765,7 +725,8 @@ class ImmutableGraph: public GraphInterface {
* \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray Edges(const std::string &order = "") const override;
......@@ -809,13 +770,14 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
* The induced subgraph is a subgraph formed by specifying a set of vertices
* V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the vertices preserve the order of the given id
* array, while the local index of the edges preserve the index order in the
* original graph. Vertices not in the original graph are ignored.
*
* The result subgraph is read-only.
*
......@@ -827,20 +789,22 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
* The induced edges subgraph is a subgraph formed by specifying a set of
* edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the edges preserve the order of the given id
* array, while the local index of the vertices preserve the index order in
* the original graph. Edges not in the original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*!
* \brief Return the successor vector
......@@ -887,7 +851,8 @@ class ImmutableGraph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray.
*/
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const override;
/* !\brief Return in csr. If not exist, transpose the other one.*/
CSRPtr GetInCSR() const;
......@@ -900,14 +865,15 @@ class ImmutableGraph: public GraphInterface {
/*! \brief Create an immutable graph from CSR. */
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);
/*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst,
bool row_osrted = false, bool col_sorted = false);
int64_t num_vertices, IdArray src, IdArray dst, bool row_osrted = false,
bool col_sorted = false);
/*!
* \brief Convert the given graph to an immutable graph.
......@@ -925,14 +891,15 @@ class ImmutableGraph: public GraphInterface {
* \param ctx The target context.
* \return The graph under another context.
*/
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx);
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext &ctx);
/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name);
static ImmutableGraphPtr CopyToSharedMem(
ImmutableGraphPtr g, const std::string &name);
/*!
* \brief Convert the graph to use the given number of bits for storage.
......@@ -944,7 +911,8 @@ class ImmutableGraph: public GraphInterface {
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
* The returned graph preserves the vertex and edge index in the original
* graph.
*
* \return the reversed graph
*/
......@@ -954,20 +922,16 @@ class ImmutableGraph: public GraphInterface {
bool Load(dmlc::Stream *fs);
/*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const;
void Save(dmlc::Stream *fs) const;
void SortCSR() override {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
}
bool HasInCSR() const {
return in_csr_ != NULL;
}
bool HasInCSR() const { return in_csr_ != NULL; }
bool HasOutCSR() const {
return out_csr_ != NULL;
}
bool HasOutCSR() const { return out_csr_ != NULL; }
/*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const;
......@@ -981,12 +945,13 @@ class ImmutableGraph: public GraphInterface {
/* !\brief internal constructor for all the members */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
: in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(AnyGraph()) << "At least one graph structure should exist.";
}
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
: in_csr_(in_csr), out_csr_(out_csr) {
ImmutableGraph(
CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
this->shared_mem_name_ = shared_mem_name;
}
......@@ -1025,18 +990,19 @@ class ImmutableGraph: public GraphInterface {
// inline implementations
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
CSR::CSR(
int64_t num_vertices, int64_t num_edges, IndptrIter indptr_begin,
IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
// TODO(minjie): this should be changed to a device-agnostic implementation
// in the future
// in the future.
adj_.num_rows = num_vertices;
adj_.num_cols = num_vertices;
adj_.indptr = aten::NewIdArray(num_vertices + 1);
adj_.indices = aten::NewIdArray(num_edges);
adj_.data = aten::NewIdArray(num_edges);
dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(adj_.data->data);
dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);
dgl_id_t *edge_ids_data = static_cast<dgl_id_t *>(adj_.data->data);
for (int64_t i = 0; i < num_vertices + 1; ++i)
*(indptr_data++) = *(indptr_begin++);
for (int64_t i = 0; i < num_edges; ++i) {
......
......@@ -7,12 +7,12 @@
#define DGL_KERNEL_H_
#include <string>
#include <vector>
#include <utility>
#include <vector>
#include "array.h"
#include "./bcast.h"
#include "./base_heterograph.h"
#include "./bcast.h"
#include "array.h"
namespace dgl {
namespace aten {
......@@ -30,12 +30,9 @@ namespace aten {
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void SpMM(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
void SpMM(
const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication.
......@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce,
* \param vfeat The destination node feature.
* \param out The output feature on edge.
*/
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out);
void SDDMM(
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray out);
/*!
* \brief Sparse-sparse matrix multiplication.
*
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a B_weights
* are 1D vectors.)
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a
* B_weights are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);
CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);
/*!
* \brief Summing up a list of sparse matrices.
......@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM(
* are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
} // namespace aten
} // namespace dgl
......
......@@ -22,15 +22,17 @@ class Lazy {
/*!\brief default constructor to construct a lazy object */
Lazy() {}
/*!\brief constructor to construct an object with given value (non-lazy case) */
explicit Lazy(const T& val): ptr_(new T(val)) {}
/*!
* \brief constructor to construct an object with given value (non-lazy case)
*/
explicit Lazy(const T& val) : ptr_(new T(val)) {}
/*!\brief destructor */
~Lazy() = default;
/*!
* \brief Get the value of this object. If the object has not been instantiated,
* using the provided function to create it.
* \brief Get the value of this object. If the object has not been
* instantiated, using the provided function to create it.
* \param fn The creator function.
* \return the object value.
*/
......
......@@ -6,9 +6,9 @@
#ifndef DGL_NODEFLOW_H_
#define DGL_NODEFLOW_H_
#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "./runtime/object.h"
#include "graph_interface.h"
......@@ -18,12 +18,12 @@ namespace dgl {
class ImmutableGraph;
/*!
* \brief A NodeFlow graph stores the sampling results for a sampler that samples
* nodes/edges in layers.
* \brief A NodeFlow graph stores the sampling results for a sampler that
* samples nodes/edges in layers.
*
* We store multiple layers of the sampling results in a single graph, which results
* in a more compact format. We store extra information,
* such as the node and edge mapping from the NodeFlow graph to the parent graph.
* We store multiple layers of the sampling results in a single graph, which
* results in a more compact format. We store extra information, such as the
* node and edge mapping from the NodeFlow graph to the parent graph.
*/
struct NodeFlowObject : public runtime::Object {
/*! \brief The graph. */
......@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object {
*/
IdArray edge_mapping;
static constexpr const char* _type_key = "graph.NodeFlow";
static constexpr const char *_type_key = "graph.NodeFlow";
DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);
};
......@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef {
* of an edge and the column represents the source.
*
* If fmt == "csr", the function returns three arrays: indptr, indices, eid.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx array
* is the concatenation of src and dst node id arrays.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx
* array is the concatenation of src and dst node id arrays.
*
* \param graph An immutable graph.
* \param fmt the format of the returned adjacency matrix.
......@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef {
* space.
* \return a vector of IdArrays.
*/
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt,
size_t layer0_size, size_t layer1_start,
size_t layer1_end, bool remap);
std::vector<IdArray> GetNodeFlowSlice(
const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_start, size_t layer1_end, bool remap);
} // namespace dgl
#endif // DGL_NODEFLOW_H_
......@@ -7,14 +7,14 @@
#ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_
#include <memory>
#include <sstream>
#include <string>
#include <memory>
#include <type_traits>
#include "./runtime/packed_func.h"
#include "./runtime/object.h"
#include "./runtime/container.h"
#include "./runtime/object.h"
#include "./runtime/packed_func.h"
namespace dgl {
namespace runtime {
......@@ -22,7 +22,7 @@ namespace runtime {
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
template <typename T>
struct ObjectTypeChecker {
static inline bool Check(Object* sptr) {
// This is the only place in the project where RTTI is used
......@@ -31,13 +31,13 @@ struct ObjectTypeChecker {
using ContainerType = typename T::ContainerType;
return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};
template<typename T>
template <typename T>
struct ObjectTypeChecker<List<T> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -48,14 +48,14 @@ struct ObjectTypeChecker<List<T> > {
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "list<";
ObjectTypeChecker<T>::PrintName(os);
os << ">";
}
};
template<typename V>
template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -66,7 +66,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
ObjectTypeChecker<V>::PrintName(os);
......@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
};
template<typename K, typename V>
template <typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -88,7 +86,7 @@ struct ObjectTypeChecker<Map<K, V> > {
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
ObjectTypeChecker<K>::PrintName(os);
os << ',';
......@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > {
}
};
template<typename T>
template <typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
......@@ -106,7 +104,7 @@ inline std::string NodeTypeName() {
// extensions for DGLArgValue
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef DGLArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
......@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TObjectRef>()
<< " but get " << sptr->type_key();
<< "Expected type " << NodeTypeName<TObjectRef>() << " but get "
<< sptr->type_key();
return TObjectRef(sptr);
}
......@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {
return *ptr<std::shared_ptr<Object> >();
}
template<typename TObjectRef, typename>
template <typename TObjectRef, typename>
inline bool DGLArgValue::IsObjectType() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr =
*ptr<std::shared_ptr<Object> >();
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
return ObjectTypeChecker<TObjectRef>::Check(sptr.get());
}
......@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {
return *this;
}
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef DGLRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
......@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const {
return TObjectRef(*ptr<std::shared_ptr<Object> >());
}
inline void DGLArgsSetter::operator()(size_t i, const ObjectRef& other) const { // NOLINT(*)
inline void DGLArgsSetter::operator()(
size_t i, const ObjectRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));
type_codes_[i] = kObjectHandle;
......
......@@ -8,8 +8,9 @@
#define DGL_RANDOM_H_
#include <dgl/array.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <random>
#include <thread>
#include <vector>
......@@ -46,33 +47,27 @@ class RandomEngine {
}
/*! \brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) {
SetSeed(seed);
}
explicit RandomEngine(uint32_t seed) { SetSeed(seed); }
/*! \brief Get the thread-local random number generator instance */
static RandomEngine *ThreadLocal() {
static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/*!
* \brief Set the seed of this random number generator
*/
void SetSeed(uint32_t seed) {
rng_.seed(seed + GetThreadId());
}
void SetSeed(uint32_t seed) { rng_.seed(seed + GetThreadId()); }
/*!
* \brief Generate an arbitrary random 32-bit integer.
*/
int32_t RandInt32() {
return static_cast<int32_t>(rng_());
}
int32_t RandInt32() { return static_cast<int32_t>(rng_()); }
/*!
* \brief Generate a uniform random integer in [0, upper)
*/
template<typename T>
template <typename T>
T RandInt(T upper) {
return RandInt<T>(0, upper);
}
......@@ -80,7 +75,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random integer in [lower, upper)
*/
template<typename T>
template <typename T>
T RandInt(T lower, T upper) {
CHECK_LT(lower, upper);
std::uniform_int_distribution<T> dist(lower, upper - 1);
......@@ -90,7 +85,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random float in [0, 1)
*/
template<typename T>
template <typename T>
T Uniform() {
return Uniform<T>(0., 1.);
}
......@@ -98,7 +93,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random float in [lower, upper)
*/
template<typename T>
template <typename T>
T Uniform(T lower, T upper) {
// Although the result is in [lower, upper), we allow lower == upper as in
// www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/
......@@ -108,23 +103,27 @@ class RandomEngine {
}
/*!
* \brief Pick a random integer between 0 to N-1 according to given probabilities
* \tparam IdxType Return integer type
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \brief Pick a random integer between 0 to N-1 according to given
* probabilities.
* \tparam IdxType Return integer type.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \return An integer randomly picked from 0 to N-1.
*/
template<typename IdxType>
template <typename IdxType>
IdxType Choice(FloatArray prob);
/*!
* \brief Pick random integers between 0 to N-1 according to given probabilities
*
* \brief Pick random integers between 0 to N-1 according to given
* probabilities
*
* If replace is false, the number of picked integers must not larger than N.
*
* \tparam IdxType Id type
* \tparam FloatType Probability value type
* \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement.
*/
......@@ -132,14 +131,16 @@ class RandomEngine {
void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true);
/*!
* \brief Pick random integers between 0 to N-1 according to given probabilities
*
* \brief Pick random integers between 0 to N-1 according to given
* probabilities
*
* If replace is false, the number of picked integers must not larger than N.
*
* \tparam IdxType Id type
* \tparam FloatType Probability value type
* \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param replace If true, choose with replacement.
* \return Picked indices
*/
......@@ -147,7 +148,8 @@ class RandomEngine {
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace);
Choice<IdxType, FloatType>(
num, prob, static_cast<IdxType*>(ret->data), replace);
return ret;
}
......@@ -163,7 +165,8 @@ class RandomEngine {
* \param replace If true, choose with replacement.
*/
template <typename IdxType>
void UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace = true);
void UniformChoice(
IdxType num, IdxType population, IdxType* out, bool replace = true);
/*!
* \brief Pick random integers from population by uniform distribution.
......@@ -181,43 +184,48 @@ class RandomEngine {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace);
UniformChoice<IdxType>(
num, population, static_cast<IdxType*>(ret->data), replace);
return ret;
}
/*!
* \brief Pick random integers with different probability for different segments.
* \brief Pick random integers with different probability for different
* segments.
*
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some integers
* from 0 to 9, which is divided into two segments. 0-3 are in the first segment and the rest
* belongs to the second. The weight(bias) of each candidate in the first segment is upweighted
* to 1.5.
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some
* integers from 0 to 9, which is divided into two segments. 0-3 are in the
* first segment and the rest belongs to the second. The weight(bias) of each
* candidate in the first segment is upweighted to 1.5.
*
* candidate | 0 1 2 3 | 4 5 6 7 8 9 |
* split ^ ^ ^
* bias | 1.5 | 1 |
*
*
* The complexity of this operator is O(k * log(T)) where k is the number of integers we want
* to pick, and T is the number of segments. It is much faster compared with assigning
* probability for each candidate, of which the complexity is O(k * log(N)) where N is the
* number of all candidates.
* The complexity of this operator is O(k * log(T)) where k is the number of
* integers we want to pick, and T is the number of segments. It is much
* faster compared with assigning probability for each candidate, of which the
* complexity is O(k * log(N)) where N is the number of all candidates.
*
* If replace is false, num must not be larger than population.
*
* \tparam IdxType Return integer type
* \param num Number of integers to choose
* \param split Array of T+1 split positions of different segments(including start and end)
* \param bias Array of T weight of each segments
* \param split Array of T+1 split positions of different segments(including
* start and end)
* \param bias Array of T weight of each segments.
* \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement.
*/
template <typename IdxType, typename FloatType>
void BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace = true);
IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace = true);
/*!
* \brief Pick random integers with different probability for different segments.
/*!
* \brief Pick random integers with different probability for different
* segments.
*
* If replace is false, num must not be larger than population.
*
......@@ -229,10 +237,11 @@ class RandomEngine {
*/
template <typename IdxType, typename FloatType>
IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
IdxType num, const IdxType* split, FloatArray bias, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
BiasedChoice<IdxType, FloatType>(
num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret;
}
......
......@@ -27,9 +27,8 @@ extern "C" {
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
DGLFunctionHandle *out);
DGL_DLL int DGLBackendGetFuncFromEnv(
void* mod_node, const char* func_name, DGLFunctionHandle* out);
/*!
* \brief Backend function to register system-wide library symbol.
*
......@@ -42,22 +41,21 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function to allocate temporal workspace.
*
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
* \note The result allocate spaced is ensured to be aligned to
* kTempAllocaAlignment.
*
* \param nbytes The size of the space requested.
* \param device_type The device type which the space will be allocated.
* \param device_id The device id which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL.
* certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL.
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
*/
DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
int device_id,
uint64_t nbytes,
int dtype_code_hint,
int dtype_bits_hint);
DGL_DLL void* DGLBackendAllocWorkspace(
int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
int dtype_bits_hint);
/*!
* \brief Backend function to free temporal workspace.
......@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
*
* \sa DGLBackendAllocWorkspace
*/
DGL_DLL int DGLBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);
/*!
* \brief Environment for DGL parallel task.
......@@ -100,13 +96,12 @@ typedef int (*FDGLParallelLambda)(
* \param flambda The parallel function to be launched.
* \param cdata The closure data.
* \param num_task Number of tasks to launch, can be 0, means launch
* with all available threads.
* with all available threads.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
void* cdata,
int num_task);
DGL_DLL int DGLBackendParallelLaunch(
FDGLParallelLambda flambda, void* cdata, int num_task);
/*!
* \brief BSP barrrier between parallel threads
......@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
*/
DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
/*!
* \brief Simple static initialization fucntion.
* Run f once and set handle to be not null.
......@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendRunOnce(void** handle,
int (*f)(void*),
void *cdata,
int nbytes);
DGL_DLL int DGLBackendRunOnce(
void** handle, int (*f)(void*), void* cdata, int nbytes);
#ifdef __cplusplus
} // DGL_EXTERN_C
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment