Unverified Commit 07dc8fb6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4800)



* clang-format

* manul

* manul

* manual
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 9d9280cb
......@@ -15,9 +15,9 @@ namespace featgraph {
void LoadFeatGraphModule(const std::string& path);
/* \brief Call Featgraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out);
void SDDMMTreeReduction(
DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* rhs, DLManagedTensor* out);
} // namespace featgraph
} // namespace dgl
......
......@@ -3,18 +3,18 @@
* \file featgraph/src/featgraph.cc
* \brief FeatGraph kernels.
*/
#include <dmlc/logging.h>
#include <featgraph.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <featgraph.h>
namespace dgl {
namespace featgraph {
/* \brief Singleton that loads the featgraph module. */
class FeatGraphModule {
public:
public:
static FeatGraphModule* Global() {
static FeatGraphModule inst;
return &inst;
......@@ -32,7 +32,8 @@ public:
}
return ret;
}
private:
private:
tvm::runtime::Module mod;
FeatGraphModule() {}
};
......@@ -44,34 +45,36 @@ void LoadFeatGraphModule(const std::string& path) {
/* \brief Convert DLDataType to string. */
inline std::string DTypeAsStr(const DLDataType& t) {
switch(t.code) {
case 0U: return "int" + std::to_string(t.bits);
case 1U: return "uint" + std::to_string(t.bits);
case 2U: return "float" + std::to_string(t.bits);
case 3U: return "bfloat" + std::to_string(t.bits);
default: LOG(FATAL) << "Type code " << t.code << " not recognized";
switch (t.code) {
case 0U:
return "int" + std::to_string(t.bits);
case 1U:
return "uint" + std::to_string(t.bits);
case 2U:
return "float" + std::to_string(t.bits);
case 3U:
return "bfloat" + std::to_string(t.bits);
default:
LOG(FATAL) << "Type code " << t.code << " not recognized";
}
}
/* \brief Get operator filename. */
inline std::string GetOperatorName(
const std::string& base_name,
const DLDataType& dtype,
const std::string& base_name, const DLDataType& dtype,
const DLDataType& idtype) {
return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
}
/* \brief Call FeatGraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out) {
void SDDMMTreeReduction(
DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
DLManagedTensor* rhs, DLManagedTensor* out) {
tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
std::string f_name = GetOperatorName("SDDMMTreeReduction",
(row->dl_tensor).dtype,
(lhs->dl_tensor).dtype);
std::string f_name = GetOperatorName(
"SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
if (f != nullptr)
f(row, col, lhs, rhs, out);
if (f != nullptr) f(row, col, lhs, rhs, out);
}
} // namespace featgraph
......
/*
* NOTE(zihao): this file was modified from TVM project:
* - https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
* -
* https://github.com/apache/tvm/blob/9713d675c64ae3075e10be5acadeef1328a44bb5/apps/howto_deploy/tvm_runtime_pack.cc
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
......
......@@ -8,10 +8,10 @@
*/
#ifndef DGL_ARRAY_H_
#define DGL_ARRAY_H_
#include "./aten/types.h"
#include "./aten/array_ops.h"
#include "./aten/coo.h"
#include "./aten/csr.h"
#include "./aten/macro.h"
#include "./aten/spmat.h"
#include "./aten/csr.h"
#include "./aten/coo.h"
#include "./aten/types.h"
#endif // DGL_ARRAY_H_
......@@ -12,8 +12,8 @@
#define CUB_INLINE inline
#endif // __CUDA_ARCH__
#include <iterator>
#include <algorithm>
#include <iterator>
#include <utility>
namespace dgl {
......@@ -51,8 +51,8 @@ CUB_INLINE void swap(const Pair<DType>& r1, const Pair<DType>& r2) {
r1.swap(r2);
}
// PairRef and PairIterator that serves as an iterator over a pair of arrays in a
// zipped fashion like zip(a, b).
// PairRef and PairIterator that serves as an iterator over a pair of arrays in
// a zipped fashion like zip(a, b).
template <typename DType>
struct PairRef {
PairRef() = delete;
......@@ -69,9 +69,7 @@ struct PairRef {
*b = val.second;
return *this;
}
CUB_INLINE operator Pair<DType>() const {
return Pair<DType>(*a, *b);
}
CUB_INLINE operator Pair<DType>() const { return Pair<DType>(*a, *b); }
CUB_INLINE operator std::pair<DType, DType>() const {
return std::make_pair(*a, *b);
}
......@@ -91,11 +89,9 @@ CUB_INLINE void swap(const PairRef<DType>& r1, const PairRef<DType>& r2) {
}
template <typename DType>
struct PairIterator : public std::iterator<std::random_access_iterator_tag,
Pair<DType>,
std::ptrdiff_t,
Pair<DType*>,
PairRef<DType>> {
struct PairIterator : public std::iterator<
std::random_access_iterator_tag, Pair<DType>,
std::ptrdiff_t, Pair<DType*>, PairRef<DType>> {
PairIterator() = default;
PairIterator(const PairIterator& other) = default;
PairIterator(PairIterator&& other) = default;
......@@ -103,12 +99,24 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
PairIterator& operator=(const PairIterator& other) = default;
PairIterator& operator=(PairIterator&& other) = default;
~PairIterator() = default;
CUB_INLINE bool operator==(const PairIterator& other) const { return a == other.a; }
CUB_INLINE bool operator!=(const PairIterator& other) const { return a != other.a; }
CUB_INLINE bool operator<(const PairIterator& other) const { return a < other.a; }
CUB_INLINE bool operator>(const PairIterator& other) const { return a > other.a; }
CUB_INLINE bool operator<=(const PairIterator& other) const { return a <= other.a; }
CUB_INLINE bool operator>=(const PairIterator& other) const { return a >= other.a; }
CUB_INLINE bool operator==(const PairIterator& other) const {
return a == other.a;
}
CUB_INLINE bool operator!=(const PairIterator& other) const {
return a != other.a;
}
CUB_INLINE bool operator<(const PairIterator& other) const {
return a < other.a;
}
CUB_INLINE bool operator>(const PairIterator& other) const {
return a > other.a;
}
CUB_INLINE bool operator<=(const PairIterator& other) const {
return a <= other.a;
}
CUB_INLINE bool operator>=(const PairIterator& other) const {
return a >= other.a;
}
CUB_INLINE PairIterator& operator+=(const std::ptrdiff_t& movement) {
a += movement;
b += movement;
......@@ -148,12 +156,8 @@ struct PairIterator : public std::iterator<std::random_access_iterator_tag,
CUB_INLINE std::ptrdiff_t operator-(const PairIterator& other) const {
return a - other.a;
}
CUB_INLINE PairRef<DType> operator*() const {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator*() {
return PairRef<DType>(a, b);
}
CUB_INLINE PairRef<DType> operator*() const { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator*() { return PairRef<DType>(a, b); }
CUB_INLINE PairRef<DType> operator[](size_t offset) const {
return PairRef<DType>(a + offset, b + offset);
}
......
......@@ -8,8 +8,9 @@
#include <string>
#include <vector>
#include "./types.h"
#include "../runtime/object.h"
#include "./types.h"
namespace dgl {
......@@ -56,17 +57,14 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret;
if (code & COO_CODE)
ret.push_back(SparseFormat::kCOO);
if (code & CSR_CODE)
ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC);
if (code & COO_CODE) ret.push_back(SparseFormat::kCOO);
if (code & CSR_CODE) ret.push_back(SparseFormat::kCSR);
if (code & CSC_CODE) ret.push_back(SparseFormat::kCSC);
return ret;
}
inline dgl_format_code_t
SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline dgl_format_code_t SparseFormatsToCode(
const std::vector<SparseFormat>& formats) {
dgl_format_code_t ret = 0;
for (auto format : formats) {
switch (format) {
......@@ -88,20 +86,15 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = "";
if (code & COO_CODE)
ret += "coo ";
if (code & CSR_CODE)
ret += "csr ";
if (code & CSC_CODE)
ret += "csc ";
if (code & COO_CODE) ret += "coo ";
if (code & CSR_CODE) ret += "csr ";
if (code & CSC_CODE) ret += "csc ";
return ret;
}
inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & COO_CODE)
return SparseFormat::kCOO;
if (code & CSC_CODE)
return SparseFormat::kCSC;
if (code & COO_CODE) return SparseFormat::kCOO;
if (code & CSC_CODE) return SparseFormat::kCSC;
return SparseFormat::kCSR;
}
......@@ -113,20 +106,25 @@ struct SparseMatrix : public runtime::Object {
// Shape of this matrix.
int64_t num_rows = 0, num_cols = 0;
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row, col, data}.
// Index arrays. For CSR, it is {indptr, indices, data}. For COO, it is {row,
// col, data}.
std::vector<IdArray> indices;
// Boolean flags.
// TODO(minjie): We might revisit this later to provide a more general solution. Currently,
// we only consider aten::COOMatrix and aten::CSRMatrix.
// TODO(minjie): We might revisit this later to provide a more general
// solution. Currently, we only consider aten::COOMatrix and aten::CSRMatrix.
std::vector<bool> flags;
SparseMatrix() {}
SparseMatrix(int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx,
const std::vector<bool>& flg)
: format(fmt), num_rows(nrows), num_cols(ncols), indices(idx), flags(flg) {}
SparseMatrix(
int32_t fmt, int64_t nrows, int64_t ncols,
const std::vector<IdArray>& idx, const std::vector<bool>& flg)
: format(fmt),
num_rows(nrows),
num_cols(ncols),
indices(idx),
flags(flg) {}
static constexpr const char* _type_key = "aten.SparseMatrix";
DGL_DECLARE_OBJECT_TYPE_INFO(SparseMatrix, runtime::Object);
......
......@@ -7,6 +7,7 @@
#define DGL_ATEN_TYPES_H_
#include <cstdint>
#include "../runtime/ndarray.h"
namespace dgl {
......
This diff is collapsed.
......@@ -7,6 +7,7 @@
#include <string>
#include <vector>
#include "./runtime/ndarray.h"
using namespace dgl::runtime;
......@@ -64,4 +65,3 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs);
} // namespace dgl
#endif // DGL_BCAST_H_
......@@ -6,13 +6,12 @@
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <string>
#include <vector>
#include <string>
#include <cstdint>
#include <utility>
#include <tuple>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "graph_interface.h"
......@@ -23,7 +22,7 @@ class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! \brief Mutable graph based on adjacency list. */
class Graph: public GraphInterface {
class Graph : public GraphInterface {
public:
/*! \brief default constructor */
Graph() {}
......@@ -89,13 +88,9 @@ class Graph: public GraphInterface {
num_edges_ = 0;
}
DGLContext Context() const override {
return DGLContext{kDGLCPU, 0};
}
DGLContext Context() const override { return DGLContext{kDGLCPU, 0}; }
uint8_t NumBits() const override {
return 64;
}
uint8_t NumBits() const override { return 64; }
/*!
* \note not const since we have caches
......@@ -106,21 +101,17 @@ class Graph: public GraphInterface {
/*!
* \return whether the graph is read-only
*/
bool IsReadonly() const override {
return false;
}
bool IsReadonly() const override { return false; }
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const override {
return adjlist_.size();
}
uint64_t NumVertices() const override { return adjlist_.size(); }
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const override {
return num_edges_;
}
uint64_t NumEdges() const override { return num_edges_; }
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/
......@@ -132,7 +123,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override;
......@@ -140,7 +132,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const override;
......@@ -169,7 +162,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
return std::make_pair(all_edges_src_[eid], all_edges_dst_[eid]);
......@@ -178,7 +172,8 @@ class Graph: public GraphInterface {
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
EdgeArray FindEdges(IdArray eids) const override;
......@@ -216,10 +211,11 @@ class Graph: public GraphInterface {
* \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \param sorted Whether the returned edge list is sorted by their src and dst
* ids.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray Edges(const std::string &order = "") const override;
EdgeArray Edges(const std::string& order = "") const override;
/*!
* \brief Get the in degree of the given vertex.
......@@ -258,13 +254,14 @@ class Graph: public GraphInterface {
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
* The induced subgraph is a subgraph formed by specifying a set of vertices
* V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the vertices preserve the order of the given id
* array, while the local index of the edges preserve the index order in the
* original graph. Vertices not in the original graph are ignored.
*
* The result subgraph is read-only.
*
......@@ -276,20 +273,22 @@ class Graph: public GraphInterface {
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
* The induced edges subgraph is a subgraph formed by specifying a set of
* edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the edges preserve the order of the given id
* array, while the local index of the vertices preserve the index order in
* the original graph. Edges not in the original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const override;
/*!
* \brief Return the successor vector
......@@ -344,12 +343,11 @@ class Graph: public GraphInterface {
* \param fmt the format of the returned adjacency matrix.
* \return a vector of three IdArray.
*/
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
std::vector<IdArray> GetAdj(
bool transpose, const std::string& fmt) const override;
/*! \brief Create an empty graph */
static MutableGraphPtr Create() {
return std::make_shared<Graph>();
}
static MutableGraphPtr Create() { return std::make_shared<Graph>(); }
/*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO(
......
......@@ -6,11 +6,11 @@
#ifndef DGL_GRAPH_INTERFACE_H_
#define DGL_GRAPH_INTERFACE_H_
#include <string>
#include <vector>
#include <utility>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "./runtime/object.h"
#include "array.h"
......@@ -23,7 +23,8 @@ const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
* \brief This class references data in std::vector.
*
* This isn't a STL-style iterator. It provides a STL data container interface.
* but it doesn't own data itself. instead, it only references data in std::vector.
* but it doesn't own data itself. instead, it only references data in
* std::vector.
*/
class DGLIdIters {
public:
......@@ -34,18 +35,11 @@ class DGLIdIters {
this->begin_ = begin;
this->end_ = end;
}
const dgl_id_t *begin() const {
return this->begin_;
}
const dgl_id_t *end() const {
return this->end_;
}
dgl_id_t operator[](int64_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
const dgl_id_t *begin() const { return this->begin_; }
const dgl_id_t *end() const { return this->end_; }
dgl_id_t operator[](int64_t i) const { return *(this->begin_ + i); }
size_t size() const { return this->end_ - this->begin_; }
private:
const dgl_id_t *begin_{nullptr}, *end_{nullptr};
};
......@@ -63,23 +57,15 @@ class DGLIdIters32 {
this->begin_ = begin;
this->end_ = end;
}
const int32_t *begin() const {
return this->begin_;
}
const int32_t *end() const {
return this->end_;
}
int32_t operator[](int32_t i) const {
return *(this->begin_ + i);
}
size_t size() const {
return this->end_ - this->begin_;
}
const int32_t *begin() const { return this->begin_; }
const int32_t *end() const { return this->end_; }
int32_t operator[](int32_t i) const { return *(this->begin_ + i); }
size_t size() const { return this->end_ - this->begin_; }
private:
const int32_t *begin_{nullptr}, *end_{nullptr};
};
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
......@@ -140,7 +126,8 @@ class GraphInterface : public runtime::Object {
virtual DGLContext Context() const = 0;
/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
* \brief Get the number of integer bits used to store node/edge ids
* (32 or 64).
*/
virtual uint8_t NumBits() const = 0;
......@@ -192,11 +179,11 @@ class GraphInterface : public runtime::Object {
virtual uint64_t NumEdges() const = 0;
/*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}
virtual bool HasVertex(dgl_id_t vid) const { return vid < NumVertices(); }
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
/*! \return a 0-1 array indicating whether the given vertices are in the
* graph.
*/
virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/
......@@ -208,7 +195,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the predecessor id array.
*/
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;
......@@ -216,7 +204,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \param radius The radius of the neighborhood. Default is immediate neighbor
* (radius=1).
* \return the successor id array.
*/
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;
......@@ -245,14 +234,16 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Find the edge ID and return the pair of endpoints
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
* \return a pair whose first element is the source and the second the
* destination.
*/
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
* \return EdgeArray containing all edges with id in eid. The order is
* preserved.
*/
virtual EdgeArray FindEdges(IdArray eids) const = 0;
......@@ -288,8 +279,8 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and
* dst ids. If order is "eid", they are in their edge id order.
* \note If order is "srcdst", the returned edges list is sorted by their src
* and dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order.
* \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges.
......@@ -327,13 +318,14 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
* The induced subgraph is a subgraph formed by specifying a set of vertices
* V' and then selecting all of the edges from the original graph that connect
* two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the vertices preserve the order of the given id
* array, while the local index of the edges preserve the index order in the
* original graph. Vertices not in the original graph are ignored.
*
* The result subgraph is read-only.
*
......@@ -345,22 +337,24 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
* The induced edges subgraph is a subgraph formed by specifying a set of
* edges E' and then selecting all of the nodes from the original graph that
* are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
* Vertices and edges in the original graph will be "reindexed" to local
* index. The local index of the edges preserve the order of the given id
* array, while the local index of the vertices preserve the index order in
* the original graph. Edges not in the original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices
* may have no incident edges.
* \param preserve_nodes If true, the vertices will not be relabeled, so some
* vertices may have no incident edges.
* \return the induced edge subgraph
*/
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;
virtual Subgraph EdgeSubgraph(
IdArray eids, bool preserve_nodes = false) const = 0;
/*!
* \brief Return the successor vector
......@@ -399,14 +393,15 @@ class GraphInterface : public runtime::Object {
* If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids
*
* If the fmt is 'coo', the function should return one array of shape (2, nnz),
* representing a horitonzal stack of row and col indices.
* If the fmt is 'coo', the function should return one array of shape (2,
* nnz), representing a horitonzal stack of row and col indices.
*
* \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays.
*/
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;
virtual std::vector<IdArray> GetAdj(
bool transpose, const std::string &fmt) const = 0;
/*!
* \brief Sort the columns in CSR.
......@@ -414,10 +409,9 @@ class GraphInterface : public runtime::Object {
* This sorts the columns in each row based on the column Ids.
* The edge ids should be sorted accordingly.
*/
virtual void SortCSR() {
}
virtual void SortCSR() {}
static constexpr const char* _type_key = "graph.Graph";
static constexpr const char *_type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
};
......@@ -430,21 +424,23 @@ struct Subgraph : public runtime::Object {
GraphPtr graph;
/*!
* \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph.
* \note This is also a map from the new vertex id to the vertex id in the
* parent graph.
*/
IdArray induced_vertices;
/*!
* \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph.
* \note This is also a map from the new edge id to the edge id in the parent
* graph.
*/
IdArray induced_edges;
static constexpr const char* _type_key = "graph.Subgraph";
static constexpr const char *_type_key = "graph.Subgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(Subgraph, runtime::Object);
};
/*! \brief Subgraph data structure for negative subgraph */
struct NegSubgraph: public Subgraph {
struct NegSubgraph : public Subgraph {
/*! \brief The existence of the negative edges in the parent graph. */
IdArray exist;
......@@ -456,7 +452,7 @@ struct NegSubgraph: public Subgraph {
};
/*! \brief Subgraph data structure for halo subgraph */
struct HaloSubgraph: public Subgraph {
struct HaloSubgraph : public Subgraph {
/*! \brief Indicate if a node belongs to the partition. */
IdArray inner_nodes;
};
......
......@@ -7,6 +7,7 @@
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
#include "immutable_graph.h"
......@@ -17,7 +18,8 @@ class GraphOp {
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
* The returned graph preserves the vertex and edge index in the original
* graph.
*
* \return the reversed graph
*/
......@@ -40,10 +42,10 @@ class GraphOp {
* \brief Return a disjoint union of the input graphs.
*
* The new graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph sizes
* in the given sequence order. For example, giving input [g1, g2, g3], where
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
* in the result graph. Edge ids are re-assigned similarly.
* Nodes/Edges will be relabled by adding the cumsum of the previous graph
* sizes in the given sequence order. For example, giving input [g1, g2, g3],
* where they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become
* node#7 in the result graph. Edge ids are re-assigned similarly.
*
* The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type.
......@@ -67,7 +69,8 @@ class GraphOp {
* \param num The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionByNum(GraphPtr graph, int64_t num);
static std::vector<GraphPtr> DisjointPartitionByNum(
GraphPtr graph, int64_t num);
/*!
* \brief Partition the graph into several subgraphs.
......@@ -83,16 +86,17 @@ class GraphOp {
* \param sizes The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<GraphPtr> DisjointPartitionBySizes(GraphPtr graph, IdArray sizes);
static std::vector<GraphPtr> DisjointPartitionBySizes(
GraphPtr graph, IdArray sizes);
/*!
* \brief Map vids in the parent graph to the vids in the subgraph.
*
* If the Id doesn't exist in the subgraph, -1 will be used.
*
* \param parent_vid_map An array that maps the vids in the parent graph to the
* subgraph. The elements store the vertex Ids in the parent graph, and the
* indices indicate the vertex Ids in the subgraph.
* \param parent_vid_map An array that maps the vids in the parent graph to
* the subgraph. The elements store the vertex Ids in the parent graph, and
* the indices indicate the vertex Ids in the subgraph.
* \param query The vertex Ids in the parent graph.
* \return an Id array that contains the subgraph node Ids.
*/
......@@ -134,15 +138,18 @@ class GraphOp {
static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable.
* \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable.
* \param graph The input graph.
* \return a new immutable bidirected graph.
* \return a new immutable bidirected
* graph.
*/
static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
/*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable
* and call gk_csr_MakeSymmetric in GKlib. This is more efficient than ToBidirectedImmutableGraph.
* It return a null pointer if the conversion fails.
* \brief Same as BidirectedMutableGraph except that the returned graph is
* immutable and call gk_csr_MakeSymmetric in GKlib. This is more efficient
* than ToBidirectedImmutableGraph. It return a null pointer if the conversion
* fails.
*
* \param graph The input graph.
* \return a new immutable bidirected graph.
......@@ -151,21 +158,25 @@ class GraphOp {
/*!
* \brief Get a induced subgraph with HALO nodes.
* The HALO nodes are the ones that can be reached from `nodes` within `num_hops`.
* The HALO nodes are the ones that can be reached from `nodes` within
* `num_hops`.
* \param graph The input graph.
* \param nodes The input nodes that form the core of the induced subgraph.
* \param num_hops The number of hops to reach.
* \return the induced subgraph with HALO nodes.
*/
static HaloSubgraph GetSubgraphWithHalo(GraphPtr graph, IdArray nodes, int num_hops);
static HaloSubgraph GetSubgraphWithHalo(
GraphPtr graph, IdArray nodes, int num_hops);
/*!
* \brief Reorder the nodes in the immutable graph.
* \param graph The input graph.
* \param new_order The node Ids in the new graph. The index in `new_order` is old node Ids.
* \param new_order The node Ids in the new graph. The index in `new_order` is
* old node Ids.
* \return the graph with reordered node Ids
*/
static GraphPtr ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order);
static GraphPtr ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order);
};
} // namespace dgl
......
......@@ -16,7 +16,8 @@ namespace dgl {
* \brief Class for representing frontiers.
*
* Each frontier is a list of nodes/edges (specified by their ids).
* An optional tag can be specified on each node/edge (represented by an int value).
* An optional tag can be specified on each node/edge (represented by an int
* value).
*/
struct Frontiers {
/*!\brief a vector store for the nodes/edges in all the frontiers */
......@@ -78,10 +79,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* FORWARD(0), REVERSE(1), NONTREE(2)
*
* A FORWARD edge is one in which `u` has been visisted but `v` has not.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the edge
* is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the edge
* is NOT in the DFS tree.
* A REVERSE edge is one in which both `u` and `v` have been visisted and the
* edge is in the DFS tree.
* A NONTREE edge is one in which both `u` and `v` have been visisted but the
* edge is NOT in the DFS tree.
*
* \param csr The input csr matrix.
* \param sources Source nodes.
......@@ -90,11 +91,9 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
* \param return_labels If true, return the recorded edge tags.
* \return A Frontiers object containing the search result
*/
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
Frontiers DGLDFSLabeledEdges(
const CSRMatrix& csr, IdArray source, const bool has_reverse_edge,
const bool has_nontree_edge, const bool return_labels);
} // namespace aten
} // namespace dgl
......
This diff is collapsed.
......@@ -7,12 +7,12 @@
#define DGL_KERNEL_H_
#include <string>
#include <vector>
#include <utility>
#include <vector>
#include "array.h"
#include "./bcast.h"
#include "./base_heterograph.h"
#include "./bcast.h"
#include "array.h"
namespace dgl {
namespace aten {
......@@ -30,12 +30,9 @@ namespace aten {
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void SpMM(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
void SpMM(
const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication.
......@@ -46,23 +43,18 @@ void SpMM(const std::string& op, const std::string& reduce,
* \param vfeat The destination node feature.
* \param out The output feature on edge.
*/
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
void SDDMM(
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray out);
/*!
* \brief Sparse-sparse matrix multiplication.
*
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a B_weights
* are 1D vectors.)
* The sparse matrices must have scalar weights (i.e. \a A_weights and \a
* B_weights are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);
CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights);
/*!
* \brief Summing up a list of sparse matrices.
......@@ -71,8 +63,7 @@ std::pair<CSRMatrix, NDArray> CSRMM(
* are 1D vectors.)
*/
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
} // namespace aten
} // namespace dgl
......
......@@ -22,15 +22,17 @@ class Lazy {
/*!\brief default constructor to construct a lazy object */
Lazy() {}
/*!\brief constructor to construct an object with given value (non-lazy case) */
explicit Lazy(const T& val): ptr_(new T(val)) {}
/*!
* \brief constructor to construct an object with given value (non-lazy case)
*/
explicit Lazy(const T& val) : ptr_(new T(val)) {}
/*!\brief destructor */
~Lazy() = default;
/*!
* \brief Get the value of this object. If the object has not been instantiated,
* using the provided function to create it.
* \brief Get the value of this object. If the object has not been
* instantiated, using the provided function to create it.
* \param fn The creator function.
* \return the object value.
*/
......
......@@ -6,9 +6,9 @@
#ifndef DGL_NODEFLOW_H_
#define DGL_NODEFLOW_H_
#include <vector>
#include <string>
#include <memory>
#include <string>
#include <vector>
#include "./runtime/object.h"
#include "graph_interface.h"
......@@ -18,12 +18,12 @@ namespace dgl {
class ImmutableGraph;
/*!
* \brief A NodeFlow graph stores the sampling results for a sampler that samples
* nodes/edges in layers.
* \brief A NodeFlow graph stores the sampling results for a sampler that
* samples nodes/edges in layers.
*
* We store multiple layers of the sampling results in a single graph, which results
* in a more compact format. We store extra information,
* such as the node and edge mapping from the NodeFlow graph to the parent graph.
* We store multiple layers of the sampling results in a single graph, which
* results in a more compact format. We store extra information, such as the
* node and edge mapping from the NodeFlow graph to the parent graph.
*/
struct NodeFlowObject : public runtime::Object {
/*! \brief The graph. */
......@@ -45,7 +45,7 @@ struct NodeFlowObject : public runtime::Object {
*/
IdArray edge_mapping;
static constexpr const char* _type_key = "graph.NodeFlow";
static constexpr const char *_type_key = "graph.NodeFlow";
DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);
};
......@@ -80,8 +80,8 @@ class NodeFlow : public runtime::ObjectRef {
* of an edge and the column represents the source.
*
* If fmt == "csr", the function returns three arrays: indptr, indices, eid.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx array
* is the concatenation of src and dst node id arrays.
* If fmt == "coo", the function returns two arrays: idx, eid. Here, the idx
* array is the concatenation of src and dst node id arrays.
*
* \param graph An immutable graph.
* \param fmt the format of the returned adjacency matrix.
......@@ -92,11 +92,10 @@ class NodeFlow : public runtime::ObjectRef {
* space.
* \return a vector of IdArrays.
*/
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt,
size_t layer0_size, size_t layer1_start,
size_t layer1_end, bool remap);
std::vector<IdArray> GetNodeFlowSlice(
const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_start, size_t layer1_end, bool remap);
} // namespace dgl
#endif // DGL_NODEFLOW_H_
......@@ -7,14 +7,14 @@
#ifndef DGL_PACKED_FUNC_EXT_H_
#define DGL_PACKED_FUNC_EXT_H_
#include <memory>
#include <sstream>
#include <string>
#include <memory>
#include <type_traits>
#include "./runtime/packed_func.h"
#include "./runtime/object.h"
#include "./runtime/container.h"
#include "./runtime/object.h"
#include "./runtime/packed_func.h"
namespace dgl {
namespace runtime {
......@@ -22,7 +22,7 @@ namespace runtime {
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
template <typename T>
struct ObjectTypeChecker {
static inline bool Check(Object* sptr) {
// This is the only place in the project where RTTI is used
......@@ -37,7 +37,7 @@ struct ObjectTypeChecker {
}
};
template<typename T>
template <typename T>
struct ObjectTypeChecker<List<T> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -55,7 +55,7 @@ struct ObjectTypeChecker<List<T> > {
}
};
template<typename V>
template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -74,9 +74,7 @@ struct ObjectTypeChecker<Map<std::string, V> > {
}
};
template<typename K, typename V>
template <typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static inline bool Check(Object* sptr) {
if (sptr == nullptr) return false;
......@@ -97,7 +95,7 @@ struct ObjectTypeChecker<Map<K, V> > {
}
};
template<typename T>
template <typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
ObjectTypeChecker<T>::PrintName(os);
......@@ -106,7 +104,7 @@ inline std::string NodeTypeName() {
// extensions for DGLArgValue
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef DGLArgValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
......@@ -115,8 +113,8 @@ inline TObjectRef DGLArgValue::AsObjectRef() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
CHECK(ObjectTypeChecker<TObjectRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TObjectRef>()
<< " but get " << sptr->type_key();
<< "Expected type " << NodeTypeName<TObjectRef>() << " but get "
<< sptr->type_key();
return TObjectRef(sptr);
}
......@@ -125,12 +123,10 @@ inline std::shared_ptr<Object>& DGLArgValue::obj_sptr() {
return *ptr<std::shared_ptr<Object> >();
}
template<typename TObjectRef, typename>
template <typename TObjectRef, typename>
inline bool DGLArgValue::IsObjectType() const {
DGL_CHECK_TYPE_CODE(type_code_, kObjectHandle);
std::shared_ptr<Object>& sptr =
*ptr<std::shared_ptr<Object> >();
std::shared_ptr<Object>& sptr = *ptr<std::shared_ptr<Object> >();
return ObjectTypeChecker<TObjectRef>::Check(sptr.get());
}
......@@ -155,7 +151,7 @@ inline DGLRetValue& DGLRetValue::operator=(const ObjectRef& other) {
return *this;
}
template<typename TObjectRef>
template <typename TObjectRef>
inline TObjectRef DGLRetValue::AsObjectRef() const {
static_assert(
std::is_base_of<ObjectRef, TObjectRef>::value,
......@@ -165,7 +161,8 @@ inline TObjectRef DGLRetValue::AsObjectRef() const {
return TObjectRef(*ptr<std::shared_ptr<Object> >());
}
inline void DGLArgsSetter::operator()(size_t i, const ObjectRef& other) const { // NOLINT(*)
inline void DGLArgsSetter::operator()(
size_t i, const ObjectRef& other) const { // NOLINT(*)
if (other.defined()) {
values_[i].v_handle = const_cast<std::shared_ptr<Object>*>(&(other.obj_));
type_codes_[i] = kObjectHandle;
......
......@@ -8,8 +8,9 @@
#define DGL_RANDOM_H_
#include <dgl/array.h>
#include <dmlc/thread_local.h>
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <random>
#include <thread>
#include <vector>
......@@ -46,33 +47,27 @@ class RandomEngine {
}
/*! \brief Constructor with given seed */
explicit RandomEngine(uint32_t seed) {
SetSeed(seed);
}
explicit RandomEngine(uint32_t seed) { SetSeed(seed); }
/*! \brief Get the thread-local random number generator instance */
static RandomEngine *ThreadLocal() {
static RandomEngine* ThreadLocal() {
return dmlc::ThreadLocalStore<RandomEngine>::Get();
}
/*!
* \brief Set the seed of this random number generator
*/
void SetSeed(uint32_t seed) {
rng_.seed(seed + GetThreadId());
}
void SetSeed(uint32_t seed) { rng_.seed(seed + GetThreadId()); }
/*!
* \brief Generate an arbitrary random 32-bit integer.
*/
int32_t RandInt32() {
return static_cast<int32_t>(rng_());
}
int32_t RandInt32() { return static_cast<int32_t>(rng_()); }
/*!
* \brief Generate a uniform random integer in [0, upper)
*/
template<typename T>
template <typename T>
T RandInt(T upper) {
return RandInt<T>(0, upper);
}
......@@ -80,7 +75,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random integer in [lower, upper)
*/
template<typename T>
template <typename T>
T RandInt(T lower, T upper) {
CHECK_LT(lower, upper);
std::uniform_int_distribution<T> dist(lower, upper - 1);
......@@ -90,7 +85,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random float in [0, 1)
*/
template<typename T>
template <typename T>
T Uniform() {
return Uniform<T>(0., 1.);
}
......@@ -98,7 +93,7 @@ class RandomEngine {
/*!
* \brief Generate a uniform random float in [lower, upper)
*/
template<typename T>
template <typename T>
T Uniform(T lower, T upper) {
// Although the result is in [lower, upper), we allow lower == upper as in
// www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/
......@@ -108,23 +103,27 @@ class RandomEngine {
}
/*!
* \brief Pick a random integer between 0 to N-1 according to given probabilities
* \tparam IdxType Return integer type
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \brief Pick a random integer between 0 to N-1 according to given
* probabilities.
* \tparam IdxType Return integer type.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \return An integer randomly picked from 0 to N-1.
*/
template<typename IdxType>
template <typename IdxType>
IdxType Choice(FloatArray prob);
/*!
* \brief Pick random integers between 0 to N-1 according to given probabilities
* \brief Pick random integers between 0 to N-1 according to given
* probabilities
*
* If replace is false, the number of picked integers must not larger than N.
*
* \tparam IdxType Id type
* \tparam FloatType Probability value type
* \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement.
*/
......@@ -132,14 +131,16 @@ class RandomEngine {
void Choice(IdxType num, FloatArray prob, IdxType* out, bool replace = true);
/*!
* \brief Pick random integers between 0 to N-1 according to given probabilities
* \brief Pick random integers between 0 to N-1 according to given
* probabilities
*
* If replace is false, the number of picked integers must not larger than N.
*
* \tparam IdxType Id type
* \tparam FloatType Probability value type
* \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \param prob Array of N unnormalized probability of each element. Must be
* non-negative.
* \param replace If true, choose with replacement.
* \return Picked indices
*/
......@@ -147,7 +148,8 @@ class RandomEngine {
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace);
Choice<IdxType, FloatType>(
num, prob, static_cast<IdxType*>(ret->data), replace);
return ret;
}
......@@ -163,7 +165,8 @@ class RandomEngine {
* \param replace If true, choose with replacement.
*/
template <typename IdxType>
void UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace = true);
void UniformChoice(
IdxType num, IdxType population, IdxType* out, bool replace = true);
/*!
* \brief Pick random integers from population by uniform distribution.
......@@ -181,43 +184,48 @@ class RandomEngine {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace);
UniformChoice<IdxType>(
num, population, static_cast<IdxType*>(ret->data), replace);
return ret;
}
/*!
* \brief Pick random integers with different probability for different segments.
* \brief Pick random integers with different probability for different
* segments.
*
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some integers
* from 0 to 9, which is divided into two segments. 0-3 are in the first segment and the rest
* belongs to the second. The weight(bias) of each candidate in the first segment is upweighted
* to 1.5.
* For example, if split=[0, 4, 10] and bias=[1.5, 1], it means to pick some
* integers from 0 to 9, which is divided into two segments. 0-3 are in the
* first segment and the rest belongs to the second. The weight(bias) of each
* candidate in the first segment is upweighted to 1.5.
*
* candidate | 0 1 2 3 | 4 5 6 7 8 9 |
* split ^ ^ ^
* bias | 1.5 | 1 |
*
*
* The complexity of this operator is O(k * log(T)) where k is the number of integers we want
* to pick, and T is the number of segments. It is much faster compared with assigning
* probability for each candidate, of which the complexity is O(k * log(N)) where N is the
* number of all candidates.
* The complexity of this operator is O(k * log(T)) where k is the number of
* integers we want to pick, and T is the number of segments. It is much
* faster compared with assigning probability for each candidate, of which the
* complexity is O(k * log(N)) where N is the number of all candidates.
*
* If replace is false, num must not be larger than population.
*
* \tparam IdxType Return integer type
* \param num Number of integers to choose
* \param split Array of T+1 split positions of different segments(including start and end)
* \param bias Array of T weight of each segments
* \param split Array of T+1 split positions of different segments(including
* start and end)
* \param bias Array of T weight of each segments.
* \param out The output buffer to write selected indices.
* \param replace If true, choose with replacement.
*/
template <typename IdxType, typename FloatType>
void BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace = true);
IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace = true);
/*!
* \brief Pick random integers with different probability for different segments.
* \brief Pick random integers with different probability for different
* segments.
*
* If replace is false, num must not be larger than population.
*
......@@ -229,10 +237,11 @@ class RandomEngine {
*/
template <typename IdxType, typename FloatType>
IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
IdxType num, const IdxType* split, FloatArray bias, bool replace = true) {
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
BiasedChoice<IdxType, FloatType>(
num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret;
}
......
......@@ -27,9 +27,8 @@ extern "C" {
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
DGLFunctionHandle *out);
DGL_DLL int DGLBackendGetFuncFromEnv(
void* mod_node, const char* func_name, DGLFunctionHandle* out);
/*!
* \brief Backend function to register system-wide library symbol.
*
......@@ -42,7 +41,8 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function to allocate temporal workspace.
*
* \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment.
* \note The result allocate spaced is ensured to be aligned to
* kTempAllocaAlignment.
*
* \param nbytes The size of the space requested.
* \param device_type The device type which the space will be allocated.
......@@ -53,10 +53,8 @@ DGL_DLL int DGLBackendRegisterSystemLibSymbol(const char* name, void* ptr);
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
*/
DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
int device_id,
uint64_t nbytes,
int dtype_code_hint,
DGL_DLL void* DGLBackendAllocWorkspace(
int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
int dtype_bits_hint);
/*!
......@@ -69,9 +67,7 @@ DGL_DLL void* DGLBackendAllocWorkspace(int device_type,
*
* \sa DGLBackendAllocWorkspace
*/
DGL_DLL int DGLBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
DGL_DLL int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr);
/*!
* \brief Environment for DGL parallel task.
......@@ -104,9 +100,8 @@ typedef int (*FDGLParallelLambda)(
*
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
void* cdata,
int num_task);
DGL_DLL int DGLBackendParallelLaunch(
FDGLParallelLambda flambda, void* cdata, int num_task);
/*!
* \brief BSP barrrier between parallel threads
......@@ -116,7 +111,6 @@ DGL_DLL int DGLBackendParallelLaunch(FDGLParallelLambda flambda,
*/
DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
/*!
* \brief Simple static initialization fucntion.
* Run f once and set handle to be not null.
......@@ -128,10 +122,8 @@ DGL_DLL int DGLBackendParallelBarrier(int task_id, DGLParallelGroupEnv* penv);
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
DGL_DLL int DGLBackendRunOnce(void** handle,
int (*f)(void*),
void *cdata,
int nbytes);
DGL_DLL int DGLBackendRunOnce(
void** handle, int (*f)(void*), void* cdata, int nbytes);
#ifdef __cplusplus
} // DGL_EXTERN_C
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment