Unverified Commit 9b4d6079 authored by Minjie Wang, committed by GitHub

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
/*!
* Copyright (c) 2019 by Contributors
* \file graph/bipartite.h
* \brief Bipartite graph
* \file graph/unit_graph.h
* \brief UnitGraph graph
*/
#ifndef DGL_GRAPH_BIPARTITE_H_
#define DGL_GRAPH_BIPARTITE_H_
#ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
......@@ -19,55 +19,56 @@
namespace dgl {
/*!
* \brief Bipartite graph
* \brief UnitGraph graph
*
* Bipartite graph is a special type of heterograph which has two types
* of nodes: "Src" and "Dst". All the edges are from "Src" type nodes to
* "Dst" type nodes, so there is no edge among nodes of the same type.
* UnitGraph is a special type of heterograph which either
* (1) has two types of nodes, "Src" and "Dst", with all edges going from
* "Src" type nodes to "Dst" type nodes, so there is no edge among nodes of
* the same type; its metagraph thus has two nodes and one edge between them; or
* (2) has only one node type and one edge type; its metagraph thus has one
* node and one self-loop edge.
*/
class Bipartite : public BaseHeteroGraph {
class UnitGraph : public BaseHeteroGraph {
public:
/*! \brief source node group type */
static constexpr dgl_type_t kSrcVType = 0;
/*! \brief destination node group type */
static constexpr dgl_type_t kDstVType = 1;
/*! \brief edge group type */
static constexpr dgl_type_t kEType = 0;
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
uint64_t NumVertexTypes() const override {
return 2;
inline dgl_type_t SrcType() const {
return 0;
}
inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
}
uint64_t NumEdgeTypes() const override {
return 1;
inline dgl_type_t EdgeType() const {
return 0;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for Bipartite graph. "
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself.";
return {};
}
void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
void Clear() override {
LOG(FATAL) << "Bipartite graph is not mutable.";
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
DLContext Context() const override;
......@@ -139,13 +140,14 @@ class Bipartite : public BaseHeteroGraph {
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
// creators
/*! \brief Create a bipartite graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(int64_t num_src, int64_t num_dst,
/*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col);
/*! \brief Create a bipartite graph from (out) CSR arrays */
/*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_src, int64_t num_dst,
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids);
/*! \brief Convert the graph to use the given number of bits for storage */
......@@ -173,7 +175,14 @@ class Bipartite : public BaseHeteroGraph {
aten::COOMatrix GetCOOMatrix() const;
private:
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*!
* \brief Constructor.
* \param metagraph The metagraph describing the node and edge types.
* \param in_csr The in-edge CSR structure.
* \param out_csr The out-edge CSR structure.
* \param coo The COO structure.
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;
......@@ -190,4 +199,4 @@ class Bipartite : public BaseHeteroGraph {
}; // namespace dgl
#endif // DGL_GRAPH_BIPARTITE_H_
#endif // DGL_GRAPH_UNIT_GRAPH_H_
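
For orientation, here is a minimal sketch of how the creator declarations above (CreateFromCOO with the new num_vtypes argument) might be exercised. This is not part of the diff: it assumes the aten::VecToIdArray helper is available for building IdArrays on CPU, and in practice these internal creators are reached through the Python construction API rather than called directly.

#include <vector>
#include <dgl/array.h>          // IdArray and aten::VecToIdArray (assumed available)
#include "graph/unit_graph.h"   // path relative to src/

using namespace dgl;

std::vector<HeteroGraphPtr> MakeToyUnitGraphs() {
  // Two edges in COO form: 0 -> 1 and 2 -> 0.
  IdArray row = aten::VecToIdArray<int64_t>({0, 2});
  IdArray col = aten::VecToIdArray<int64_t>({1, 0});

  // num_vtypes == 2: bipartite-style unit graph with 3 "Src" and 2 "Dst" nodes;
  // its metagraph has two nodes and one edge between them.
  HeteroGraphPtr bipartite = UnitGraph::CreateFromCOO(2, 3, 2, row, col);

  // num_vtypes == 1: homogeneous unit graph where srctype == dsttype, so the
  // source and destination node counts coincide; its metagraph has one node
  // and one self-loop edge.
  HeteroGraphPtr homo = UnitGraph::CreateFromCOO(1, 3, 3, row, col);

  return {bipartite, homo};
}

Passing num_vtypes == 1 is what lets DstType() above fall back to 0, so source and destination types coincide for homogeneous relations.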
......@@ -10,7 +10,7 @@
#include "./binary_reduce_impl_decl.h"
#include "./utils.h"
#include "../c_api_common.h"
#include "../graph/bipartite.h"
#include "../graph/unit_graph.h"
#include "./csr_interface.h"
using namespace dgl::runtime;
......@@ -243,9 +243,9 @@ class ImmutableGraphCSRWrapper : public CSRWrapper {
const ImmutableGraph* gptr_;
};
class BipartiteCSRWrapper : public CSRWrapper {
class UnitGraphCSRWrapper : public CSRWrapper {
public:
explicit BipartiteCSRWrapper(const Bipartite* graph) :
explicit UnitGraphCSRWrapper(const UnitGraph* graph) :
gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override {
......@@ -265,7 +265,7 @@ class BipartiteCSRWrapper : public CSRWrapper {
}
private:
const Bipartite* gptr_;
const UnitGraph* gptr_;
};
} // namespace
......@@ -350,9 +350,9 @@ void BinaryOpReduce(
{__VA_ARGS__} \
} else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \
HeteroGraphRef g = argval; \
auto bgptr = std::dynamic_pointer_cast<Bipartite>(g.sptr()); \
auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr()); \
CHECK_NOTNULL(bgptr); \
BipartiteCSRWrapper wrapper(bgptr.get()); \
UnitGraphCSRWrapper wrapper(bgptr.get()); \
{__VA_ARGS__} \
} \
} while (0)
......
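
The wrapper above is one half of a small adapter layer: the binary-reduce kernels consume graphs only through the CSRWrapper interface declared in csr_interface.h, so an ImmutableGraph and the new UnitGraph can drive the same kernels. The following simplified sketch shows the shape of that pattern; it assumes UnitGraph exposes GetInCSRMatrix and GetOutCSRMatrix (their use is collapsed in this diff), renames the classes to avoid clashing with the real ones, and omits the extra members the real interface has.

#include <dgl/array.h>             // aten::CSRMatrix (header location assumed)
#include "../graph/unit_graph.h"   // same include used at the top of this file

namespace dgl {
namespace kernel {

// Kernel-facing view of a graph: everything the reducers need is a CSR in
// one direction or the other.
class CSRWrapperSketch {
 public:
  virtual ~CSRWrapperSketch() = default;
  virtual aten::CSRMatrix GetInCSRMatrix() const = 0;
  virtual aten::CSRMatrix GetOutCSRMatrix() const = 0;
};

// Adapter over the new UnitGraph; ImmutableGraph gets an analogous wrapper,
// so the kernels never depend on a concrete graph class.
class UnitGraphCSRWrapperSketch : public CSRWrapperSketch {
 public:
  explicit UnitGraphCSRWrapperSketch(const UnitGraph* graph) : gptr_(graph) {}
  aten::CSRMatrix GetInCSRMatrix() const override {
    return gptr_->GetInCSRMatrix();    // assumed accessor; its call site is collapsed above
  }
  aten::CSRMatrix GetOutCSRMatrix() const override {
    return gptr_->GetOutCSRMatrix();   // assumed accessor
  }
 private:
  const UnitGraph* gptr_;
};

}  // namespace kernel
}  // namespace dgl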
......@@ -36,7 +36,6 @@ GData<Idx, DType> AllocGData(const std::string& op,
// GData
GData<Idx, DType> gdata;
gdata.x_length = x_len;
gdata.out_size = out_data->shape[0];
gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data);
......@@ -127,7 +126,6 @@ BackwardGData<Idx, DType> AllocBackwardGData(
// GData
BackwardGData<Idx, DType> gdata;
gdata.x_length = x_len;
gdata.out_size = out_data->shape[0];
gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data);
......
......@@ -37,9 +37,7 @@ struct GData {
// length along x(feature) dimension
int64_t x_length{0};
// size of data, can be single value or a vector
int64_t data_len;
// number of rows of the output tensor
int64_t out_size{0};
int64_t data_len{0};
// input data
DType *lhs_data{nullptr}, *rhs_data{nullptr};
// output data
......@@ -122,9 +120,7 @@ struct BackwardGData {
// length along x(feature) dimension
int64_t x_length{0};
// size of data, can be single value or a vector
int64_t data_len;
// number of rows of the output tensor
int64_t out_size{0};
int64_t data_len{0};
// input data
DType *lhs_data{nullptr}, *rhs_data{nullptr}, *out_data{nullptr};
DType *grad_out_data{nullptr};
......@@ -227,7 +223,7 @@ struct BcastGData {
int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
// size of data, can be single value or a vector
int64_t data_len;
int64_t data_len{0};
// input data
DType *lhs_data{nullptr}, *rhs_data{nullptr};
// input id mappings
......@@ -333,7 +329,7 @@ struct BackwardBcastGData {
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
int64_t out_shape[NDim]{0}, out_stride[NDim]{0};
// size of data, can be single value or a vector
int64_t data_len;
int64_t data_len{0};
// input id mappings
Idx *lhs_mapping{nullptr}, *rhs_mapping{nullptr}, *out_mapping{nullptr};
// input data
......
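
A note on the data_len change above: these structs are instantiated with plain declarations such as GData<Idx, DType> gdata; in the alloc functions earlier in this diff, so a member without an in-class initializer starts out indeterminate. The braced {0} initializer guarantees a defined value. A self-contained illustration, with made-up struct names:

#include <cstdint>
#include <iostream>

struct WithoutInit { std::int64_t data_len;    };  // indeterminate after "WithoutInit a;"
struct WithInit    { std::int64_t data_len{0}; };  // guaranteed zero

int main() {
  WithoutInit a;  // reading a.data_len here would be undefined behavior
  WithInit b;     // b.data_len == 0 by construction
  std::cout << b.data_len << std::endl;
  (void)a;
  return 0;
}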
......@@ -12,7 +12,6 @@
#include "../csr_interface.h"
using minigun::advance::RuntimeConfig;
using Csr = minigun::Csr<int32_t>;
namespace dgl {
namespace kernel {
......@@ -84,9 +83,9 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
template <typename DType>
void CusparseCsrmm2(
const RuntimeConfig& rtcfg,
const Csr& csr,
const aten::CSRMatrix& csr,
const DType* B_data, DType* C_data,
int out_size, int x_length) {
int x_length) {
// We use csrmm2 to perform the following operation:
// C = A x B, where A is a sparse matrix in CSR format and B is the dense node
// feature matrix. However, since cusparse only supports column-major while our tensor
......@@ -94,18 +93,10 @@ void CusparseCsrmm2(
// C = trans(A x trans(B)).
// Currently, we use cublasXgeam to implement transposition and allocate intermediate
// workspace memory for this.
// TODO(minjie): The given CSR could potentially represent a bipartite graph (e.g. in the
// case of nodeflow). Currently, we don't have bipartite graph support. Here is a small
// hack. In the python side, we create a CSR that includes both the source and destination
// nodes in the bipartite graph (so it is still square matrix). Here, when multiplying
// this sparse matrix, we specify the number of rows (the `m` here) to be equal to the
// number of rows of the output tensor (i.e, the `out_size`).
// In the future, we should make sure the number of rows of the given csr is equal
// to out_size (a.k.a the given csr is a rectangle matrix).
const int m = out_size;
const int k = csr.row_offsets.length - 1;
const int m = csr.num_rows;
const int n = x_length;
const int nnz = csr.column_indices.length;
const int k = csr.num_cols;
const int nnz = csr.indices->shape[0];
const DType alpha = 1.0;
const DType beta = 0.0;
// device
......@@ -130,7 +121,9 @@ void CusparseCsrmm2(
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha,
descr, valptr, csr.row_offsets.data, csr.column_indices.data,
descr, valptr,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
device->FreeWorkspace(rtcfg.ctx, valptr);
// transpose the output matrix
......@@ -242,10 +235,9 @@ void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone,
cuda::FallbackCallBinaryReduce<float>(rtcfg, graph, gdata);
} else {
// cusparse uses the reverse (in-edge) CSR for csrmm
auto incsr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
auto csr = graph.GetInCSRMatrix();
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length);
gdata->x_length);
}
}
......@@ -259,10 +251,9 @@ void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone,
cuda::FallbackCallBinaryReduce<double>(rtcfg, graph, gdata);
} else {
// cusparse uses the reverse (in-edge) CSR for csrmm
auto incsr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
auto csr = graph.GetInCSRMatrix();
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length);
gdata->x_length);
}
}
......@@ -278,10 +269,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float,
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<float>(rtcfg, graph, gdata);
} else {
auto outcsr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
auto csr = graph.GetOutCSRMatrix();
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length);
gdata->x_length);
}
}
......@@ -295,10 +285,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double,
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<double>(rtcfg, graph, gdata);
} else {
auto outcsr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
auto csr = graph.GetOutCSRMatrix();
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length);
gdata->x_length);
}
}
......
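
Restating the column-major workaround in CusparseCsrmm2 above as math, as a sketch: assume A is the m-by-k sparse matrix with m = csr.num_rows and k = csr.num_cols, and B is the row-major k-by-n node feature matrix with n = x_length.

% Column-major trick used by CusparseCsrmm2 (sketch).
\begin{aligned}
  \tilde{B} &= B^{\top}
    && \text{(row-major buffer of } B \text{, read by cusparse as column-major)} \\
  \tilde{C} &= A\,\tilde{B}^{\top} = AB
    && \text{(csrmm2 with } \mathrm{op}(B) = \mathrm{transpose}\text{; result is column-major)} \\
  C^{\top} &= \text{row-major view of } \tilde{C}
    && \text{(so the final cublasXgeam transpose recovers } C = AB\text{)}
\end{aligned}

The actual fix is in the dimension bookkeeping: the old code assumed a square CSR padded on the Python side (the removed TODO above) and took the output row count from out_size, whereas the new code reads m and k directly from the possibly rectangular aten::CSRMatrix, which is why the out_size field could be dropped from GData and BackwardGData earlier in this diff.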
......@@ -8,6 +8,7 @@
#include <dmlc/thread_local.h>
#include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
#include <vector>
#include <string>
......@@ -62,6 +63,9 @@ struct APIAttrGetter : public AttrVisitor {
found_object_ref = true;
}
}
void Visit(const char* key, NDArray* value) final {
if (skey == key) *ret = value[0];
}
};
struct APIAttrDir : public AttrVisitor {
......@@ -88,6 +92,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, ObjectRef* value) final {
names->push_back(key);
}
void Visit(const char* key, NDArray* value) final {
names->push_back(key);
}
};
int DGLObjectFree(ObjectHandle handle) {
......
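
The two Visit overloads added above extend the attribute-visitor pattern to NDArray-valued attributes: an object advertises its fields by calling Visit once per field, and every concrete visitor needs an overload for each supported field type. A self-contained sketch of the idea follows; all names are illustrative and none of them are DGL's actual definitions.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for runtime::NDArray; only here to keep the example self-contained.
struct FakeNDArray { std::vector<float> data; };

// One Visit overload per attribute type, mirroring APIAttrGetter/APIAttrDir.
struct AttrVisitorSketch {
  virtual ~AttrVisitorSketch() = default;
  virtual void Visit(const char* key, std::int64_t* value) = 0;
  virtual void Visit(const char* key, std::string* value) = 0;
  virtual void Visit(const char* key, FakeNDArray* value) = 0;  // the newly added kind
};

// Analogue of APIAttrDir: collect the names of all visited attributes.
struct AttrDirSketch : AttrVisitorSketch {
  std::vector<std::string> names;
  void Visit(const char* key, std::int64_t*) override { names.push_back(key); }
  void Visit(const char* key, std::string*) override { names.push_back(key); }
  void Visit(const char* key, FakeNDArray*) override { names.push_back(key); }
};

// An object exposes its fields by calling Visit once per field; without the
// FakeNDArray overload the second call below would not compile.
struct MyObjectSketch {
  std::int64_t num_rows = 3;
  FakeNDArray payload;
  void VisitAttrs(AttrVisitorSketch* v) {
    v->Visit("num_rows", &num_rows);
    v->Visit("payload", &payload);
  }
};

int main() {
  MyObjectSketch obj;
  AttrDirSketch dir;
  obj.VisitAttrs(&dir);
  for (const std::string& name : dir.names) std::cout << name << "\n";  // num_rows, payload
  return 0;
}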
This diff is collapsed.
This diff is collapsed.
......@@ -11,9 +11,9 @@ def test_node_batch():
g.ndata['x'] = feat
# test all
v = ALL
v = utils.toindex(slice(0, g.number_of_nodes()))
n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr)
nbatch = NodeBatch(v, n_repr)
assert F.allclose(nbatch.data['x'], feat)
assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), g.nodes())
......@@ -23,7 +23,7 @@ def test_node_batch():
# test partial
v = utils.toindex(F.tensor([0, 3, 5, 7, 9]))
n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr)
nbatch = NodeBatch(v, n_repr)
assert F.allclose(nbatch.data['x'], F.gather_row(feat, F.tensor([0, 3, 5, 7, 9])))
assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), F.tensor([0, 3, 5, 7, 9]))
......@@ -39,13 +39,13 @@ def test_edge_batch():
g.edata['x'] = efeat
# test all
eid = ALL
eid = utils.toindex(slice(0, g.number_of_edges()))
u, v, _ = g._graph.edges('eid')
src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == g.number_of_edges() and\
F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == g.number_of_edges() and\
......@@ -64,7 +64,7 @@ def test_edge_batch():
src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data)
ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == 8 and\
F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == 8 and\
......
......@@ -192,7 +192,7 @@ function-naming-style=snake_case
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,u,v,e,n,m,w,x,y,g,fn,ex,Run,_
good-names=i,j,k,u,v,e,n,m,w,x,y,g,G,hg,fn,ex,Run,_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
......