Unverified Commit 9b4d6079 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Hetero] New syntax (#824)

* WIP. remove graph arg in NodeBatch and EdgeBatch

* refactor: use graph adapter for scheduler

* WIP: recv

* draft impl

* stuck at bipartite

* bipartite->unitgraph; support dsttype == srctype

* pass test_query

* pass test_query

* pass test_view

* test apply

* pass udf message passing tests

* pass quan's test using builtins

* WIP: wildcard slicing

* new construct methods

* broken

* good

* add stack cross reducer

* fix bug; fix mx

* fix bug in csrmm2 when the CSR is not square

* lint

* removed FlattenedHeteroGraph class

* WIP

* prop nodes, prop edges, filter nodes/edges

* add DGLGraph tests to heterograph. Fix several bugs

* finish nx<->hetero graph conversion

* create bipartite from nx

* more spec on hetero/homo conversion

* silly fixes

* check node and edge types

* repr

* to api

* adj APIs

* inc

* fix some lints and bugs

* fix some lints

* hetero/homo conversion

* fix flatten test

* more spec in hetero_from_homo and test

* flatten using concat names

* WIP: creators

* rewrite hetero_from_homo in a more efficient way

* remove useless variables

* fix lint

* subgraphs and typed subgraphs

* lint & removed heterosubgraph class

* lint x2

* disable heterograph mutation test

* docstring update

* add edge id for nx graph test

* fix mx unittests

* fix bug

* try fix

* fix unittest when cross_reducer is stack

* fix ci

* fix nx bipartite bug; docstring

* fix scipy creation bug

* lint

* fix bug when converting heterograph from homograph

* fix bug in hetero_from_homo about ntype order

* trailing white

* docstring fixes for add_foo and data views

* docstring for relation slice

* to_hetero and to_homo with feature support

* lint

* lint

* DGLGraph compatibility

* incidence matrix & docstring fixes

* example string fixes

* feature in hetero_from_relations

* deduplication of edge types in to_hetero

* fix lint

* fix
parent ddb5d804
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file graph/bipartite.h * \file graph/unit_graph.h
* \brief Bipartite graph * \brief UnitGraph graph
*/ */
#ifndef DGL_GRAPH_BIPARTITE_H_ #ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_BIPARTITE_H_ #define DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
...@@ -19,55 +19,56 @@ ...@@ -19,55 +19,56 @@
namespace dgl { namespace dgl {
/*! /*!
* \brief Bipartite graph * \brief UnitGraph graph
* *
* Bipartite graph is a special type of heterograph which has two types * UnitGraph graph is a special type of heterograph which
* of nodes: "Src" and "Dst". All the edges are from "Src" type nodes to * (1) Have two types of nodes: "Src" and "Dst". All the edges are
* "Dst" type nodes, so there is no edge among nodes of the same type. * from "Src" type nodes to "Dst" type nodes, so there is no edge among
* nodes of the same type. Thus, its metagraph has two nodes and one edge
* between them.
* (2) Have only one type of nodes and edges. Thus, its metagraph has one node
* and one self-loop edge.
*/ */
class Bipartite : public BaseHeteroGraph { class UnitGraph : public BaseHeteroGraph {
public: public:
/*! \brief source node group type */
static constexpr dgl_type_t kSrcVType = 0;
/*! \brief destination node group type */
static constexpr dgl_type_t kDstVType = 1;
/*! \brief edge group type */
static constexpr dgl_type_t kEType = 0;
// internal data structure // internal data structure
class COO; class COO;
class CSR; class CSR;
typedef std::shared_ptr<COO> COOPtr; typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr; typedef std::shared_ptr<CSR> CSRPtr;
uint64_t NumVertexTypes() const override { inline dgl_type_t SrcType() const {
return 2; return 0;
}
inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
} }
uint64_t NumEdgeTypes() const override { inline dgl_type_t EdgeType() const {
return 1; return 0;
} }
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for Bipartite graph. " LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself."; << "The relation graph is simply this graph itself.";
return {}; return {};
} }
void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override { void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override { void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override { void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void Clear() override { void Clear() override {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
DLContext Context() const override; DLContext Context() const override;
...@@ -139,13 +140,14 @@ class Bipartite : public BaseHeteroGraph { ...@@ -139,13 +140,14 @@ class Bipartite : public BaseHeteroGraph {
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
// creators // creators
/*! \brief Create a bipartite graph from COO arrays */ /*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(int64_t num_src, int64_t num_dst, static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col); IdArray row, IdArray col);
/*! \brief Create a bipartite graph from (out) CSR arrays */ /*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids); IdArray indptr, IdArray indices, IdArray edge_ids);
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
...@@ -173,7 +175,14 @@ class Bipartite : public BaseHeteroGraph { ...@@ -173,7 +175,14 @@ class Bipartite : public BaseHeteroGraph {
aten::COOMatrix GetCOOMatrix() const; aten::COOMatrix GetCOOMatrix() const;
private: private:
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo); /*!
* \brief constructor
* \param metagraph metagraph
* \param in_csr in edge csr
* \param out_csr out edge csr
* \param coo coo
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
...@@ -190,4 +199,4 @@ class Bipartite : public BaseHeteroGraph { ...@@ -190,4 +199,4 @@ class Bipartite : public BaseHeteroGraph {
}; // namespace dgl }; // namespace dgl
#endif // DGL_GRAPH_BIPARTITE_H_ #endif // DGL_GRAPH_UNIT_GRAPH_H_
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "./binary_reduce_impl_decl.h" #include "./binary_reduce_impl_decl.h"
#include "./utils.h" #include "./utils.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../graph/bipartite.h" #include "../graph/unit_graph.h"
#include "./csr_interface.h" #include "./csr_interface.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -243,9 +243,9 @@ class ImmutableGraphCSRWrapper : public CSRWrapper { ...@@ -243,9 +243,9 @@ class ImmutableGraphCSRWrapper : public CSRWrapper {
const ImmutableGraph* gptr_; const ImmutableGraph* gptr_;
}; };
class BipartiteCSRWrapper : public CSRWrapper { class UnitGraphCSRWrapper : public CSRWrapper {
public: public:
explicit BipartiteCSRWrapper(const Bipartite* graph) : explicit UnitGraphCSRWrapper(const UnitGraph* graph) :
gptr_(graph) { } gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override { aten::CSRMatrix GetInCSRMatrix() const override {
...@@ -265,7 +265,7 @@ class BipartiteCSRWrapper : public CSRWrapper { ...@@ -265,7 +265,7 @@ class BipartiteCSRWrapper : public CSRWrapper {
} }
private: private:
const Bipartite* gptr_; const UnitGraph* gptr_;
}; };
} // namespace } // namespace
...@@ -350,9 +350,9 @@ void BinaryOpReduce( ...@@ -350,9 +350,9 @@ void BinaryOpReduce(
{__VA_ARGS__} \ {__VA_ARGS__} \
} else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \ } else if (ObjectTypeChecker<HeteroGraphRef>::Check(sptr.get())) { \
HeteroGraphRef g = argval; \ HeteroGraphRef g = argval; \
auto bgptr = std::dynamic_pointer_cast<Bipartite>(g.sptr()); \ auto bgptr = std::dynamic_pointer_cast<UnitGraph>(g.sptr()); \
CHECK_NOTNULL(bgptr); \ CHECK_NOTNULL(bgptr); \
BipartiteCSRWrapper wrapper(bgptr.get()); \ UnitGraphCSRWrapper wrapper(bgptr.get()); \
{__VA_ARGS__} \ {__VA_ARGS__} \
} \ } \
} while (0) } while (0)
......
...@@ -36,7 +36,6 @@ GData<Idx, DType> AllocGData(const std::string& op, ...@@ -36,7 +36,6 @@ GData<Idx, DType> AllocGData(const std::string& op,
// GData // GData
GData<Idx, DType> gdata; GData<Idx, DType> gdata;
gdata.x_length = x_len; gdata.x_length = x_len;
gdata.out_size = out_data->shape[0];
gdata.lhs_data = static_cast<DType*>(lhs_data->data); gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data); gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data); gdata.out_data = static_cast<DType*>(out_data->data);
...@@ -127,7 +126,6 @@ BackwardGData<Idx, DType> AllocBackwardGData( ...@@ -127,7 +126,6 @@ BackwardGData<Idx, DType> AllocBackwardGData(
// GData // GData
BackwardGData<Idx, DType> gdata; BackwardGData<Idx, DType> gdata;
gdata.x_length = x_len; gdata.x_length = x_len;
gdata.out_size = out_data->shape[0];
gdata.lhs_data = static_cast<DType*>(lhs_data->data); gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data); gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data); gdata.out_data = static_cast<DType*>(out_data->data);
......
...@@ -37,9 +37,7 @@ struct GData { ...@@ -37,9 +37,7 @@ struct GData {
// length along x(feature) dimension // length along x(feature) dimension
int64_t x_length{0}; int64_t x_length{0};
// size of data, can be single value or a vector // size of data, can be single value or a vector
int64_t data_len; int64_t data_len{0};
// number of rows of the output tensor
int64_t out_size{0};
// input data // input data
DType *lhs_data{nullptr}, *rhs_data{nullptr}; DType *lhs_data{nullptr}, *rhs_data{nullptr};
// output data // output data
...@@ -122,9 +120,7 @@ struct BackwardGData { ...@@ -122,9 +120,7 @@ struct BackwardGData {
// length along x(feature) dimension // length along x(feature) dimension
int64_t x_length{0}; int64_t x_length{0};
// size of data, can be single value or a vector // size of data, can be single value or a vector
int64_t data_len; int64_t data_len{0};
// number of rows of the output tensor
int64_t out_size{0};
// input data // input data
DType *lhs_data{nullptr}, *rhs_data{nullptr}, *out_data{nullptr}; DType *lhs_data{nullptr}, *rhs_data{nullptr}, *out_data{nullptr};
DType *grad_out_data{nullptr}; DType *grad_out_data{nullptr};
...@@ -227,7 +223,7 @@ struct BcastGData { ...@@ -227,7 +223,7 @@ struct BcastGData {
int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0}; int64_t lhs_shape[NDim]{0}, lhs_stride[NDim]{0};
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0}; int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
// size of data, can be single value or a vector // size of data, can be single value or a vector
int64_t data_len; int64_t data_len{0};
// input data // input data
DType *lhs_data{nullptr}, *rhs_data{nullptr}; DType *lhs_data{nullptr}, *rhs_data{nullptr};
// input id mappings // input id mappings
...@@ -333,7 +329,7 @@ struct BackwardBcastGData { ...@@ -333,7 +329,7 @@ struct BackwardBcastGData {
int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0}; int64_t rhs_shape[NDim]{0}, rhs_stride[NDim]{0};
int64_t out_shape[NDim]{0}, out_stride[NDim]{0}; int64_t out_shape[NDim]{0}, out_stride[NDim]{0};
// size of data, can be single value or a vector // size of data, can be single value or a vector
int64_t data_len; int64_t data_len{0};
// input id mappings // input id mappings
Idx *lhs_mapping{nullptr}, *rhs_mapping{nullptr}, *out_mapping{nullptr}; Idx *lhs_mapping{nullptr}, *rhs_mapping{nullptr}, *out_mapping{nullptr};
// input data // input data
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#include "../csr_interface.h" #include "../csr_interface.h"
using minigun::advance::RuntimeConfig; using minigun::advance::RuntimeConfig;
using Csr = minigun::Csr<int32_t>;
namespace dgl { namespace dgl {
namespace kernel { namespace kernel {
...@@ -84,9 +83,9 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa, ...@@ -84,9 +83,9 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
template <typename DType> template <typename DType>
void CusparseCsrmm2( void CusparseCsrmm2(
const RuntimeConfig& rtcfg, const RuntimeConfig& rtcfg,
const Csr& csr, const aten::CSRMatrix& csr,
const DType* B_data, DType* C_data, const DType* B_data, DType* C_data,
int out_size, int x_length) { int x_length) {
// We use csrmm2 to perform following operation: // We use csrmm2 to perform following operation:
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node
// feature tensor. However, since cusparse only supports column-major, while our tensor // feature tensor. However, since cusparse only supports column-major, while our tensor
...@@ -94,18 +93,10 @@ void CusparseCsrmm2( ...@@ -94,18 +93,10 @@ void CusparseCsrmm2(
// C = trans(A x trans(B)). // C = trans(A x trans(B)).
// Currently, we use cublasXgeam to implement transposition and allocate intermediate // Currently, we use cublasXgeam to implement transposition and allocate intermediate
// workspace memory for this. // workspace memory for this.
// TODO(minjie): The given CSR could potentially represent a bipartite graph (e.g. in the const int m = csr.num_rows;
// case of nodeflow). Currently, we don't have bipartite graph support. Here is a small
// hack. In the python side, we create a CSR that includes both the source and destination
// nodes in the bipartite graph (so it is still square matrix). Here, when multiplying
// this sparse matrix, we specify the number of rows (the `m` here) to be equal to the
// number of rows of the output tensor (i.e, the `out_size`).
// In the future, we should make sure the number of rows of the given csr is equal
// to out_size (a.k.a the given csr is a rectangle matrix).
const int m = out_size;
const int k = csr.row_offsets.length - 1;
const int n = x_length; const int n = x_length;
const int nnz = csr.column_indices.length; const int k = csr.num_cols;
const int nnz = csr.indices->shape[0];
const DType alpha = 1.0; const DType alpha = 1.0;
const DType beta = 0.0; const DType beta = 0.0;
// device // device
...@@ -130,7 +121,9 @@ void CusparseCsrmm2( ...@@ -130,7 +121,9 @@ void CusparseCsrmm2(
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha, m, n, k, nnz, &alpha,
descr, valptr, csr.row_offsets.data, csr.column_indices.data, descr, valptr,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m)); B_data, n, &beta, trans_out, m));
device->FreeWorkspace(rtcfg.ctx, valptr); device->FreeWorkspace(rtcfg.ctx, valptr);
// transpose the output matrix // transpose the output matrix
...@@ -242,10 +235,9 @@ void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone, ...@@ -242,10 +235,9 @@ void CallBinaryReduce<kDLGPU, int32_t, float, SelectSrc, SelectNone,
cuda::FallbackCallBinaryReduce<float>(rtcfg, graph, gdata); cuda::FallbackCallBinaryReduce<float>(rtcfg, graph, gdata);
} else { } else {
// cusparse use rev csr for csrmm // cusparse use rev csr for csrmm
auto incsr = graph.GetInCSRMatrix(); auto csr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data, cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length); gdata->x_length);
} }
} }
...@@ -259,10 +251,9 @@ void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone, ...@@ -259,10 +251,9 @@ void CallBinaryReduce<kDLGPU, int32_t, double, SelectSrc, SelectNone,
cuda::FallbackCallBinaryReduce<double>(rtcfg, graph, gdata); cuda::FallbackCallBinaryReduce<double>(rtcfg, graph, gdata);
} else { } else {
// cusparse use rev csr for csrmm // cusparse use rev csr for csrmm
auto incsr = graph.GetInCSRMatrix(); auto csr = graph.GetInCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(incsr.indptr, incsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data, cuda::CusparseCsrmm2(rtcfg, csr, gdata->lhs_data, gdata->out_data,
gdata->out_size, gdata->x_length); gdata->x_length);
} }
} }
...@@ -278,10 +269,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float, ...@@ -278,10 +269,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, float,
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) { if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<float>(rtcfg, graph, gdata); cuda::FallbackCallBackwardBinaryReduce<float>(rtcfg, graph, gdata);
} else { } else {
auto outcsr = graph.GetOutCSRMatrix(); auto csr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data, cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length); gdata->x_length);
} }
} }
...@@ -295,10 +285,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double, ...@@ -295,10 +285,9 @@ void CallBackwardBinaryReduce<kDLGPU, binary_op::kGradLhs, int32_t, double,
if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) { if (gdata->lhs_mapping || gdata->rhs_mapping || gdata->out_mapping) {
cuda::FallbackCallBackwardBinaryReduce<double>(rtcfg, graph, gdata); cuda::FallbackCallBackwardBinaryReduce<double>(rtcfg, graph, gdata);
} else { } else {
auto outcsr = graph.GetOutCSRMatrix(); auto csr = graph.GetOutCSRMatrix();
Csr csr = utils::CreateCsr<int32_t>(outcsr.indptr, outcsr.indices);
cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data, cuda::CusparseCsrmm2(rtcfg, csr, gdata->grad_out_data, gdata->grad_lhs_data,
gdata->out_size, gdata->x_length); gdata->x_length);
} }
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/c_object_api.h> #include <dgl/runtime/c_object_api.h>
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <vector> #include <vector>
#include <string> #include <string>
...@@ -62,6 +63,9 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -62,6 +63,9 @@ struct APIAttrGetter : public AttrVisitor {
found_object_ref = true; found_object_ref = true;
} }
} }
void Visit(const char* key, NDArray* value) final {
if (skey == key) *ret = value[0];
}
}; };
struct APIAttrDir : public AttrVisitor { struct APIAttrDir : public AttrVisitor {
...@@ -88,6 +92,9 @@ struct APIAttrDir : public AttrVisitor { ...@@ -88,6 +92,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, ObjectRef* value) final { void Visit(const char* key, ObjectRef* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, NDArray* value) final {
names->push_back(key);
}
}; };
int DGLObjectFree(ObjectHandle handle) { int DGLObjectFree(ObjectHandle handle) {
......
This diff is collapsed.
This diff is collapsed.
...@@ -11,9 +11,9 @@ def test_node_batch(): ...@@ -11,9 +11,9 @@ def test_node_batch():
g.ndata['x'] = feat g.ndata['x'] = feat
# test all # test all
v = ALL v = utils.toindex(slice(0, g.number_of_nodes()))
n_repr = g.get_n_repr(v) n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr) nbatch = NodeBatch(v, n_repr)
assert F.allclose(nbatch.data['x'], feat) assert F.allclose(nbatch.data['x'], feat)
assert nbatch.mailbox is None assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), g.nodes()) assert F.allclose(nbatch.nodes(), g.nodes())
...@@ -23,7 +23,7 @@ def test_node_batch(): ...@@ -23,7 +23,7 @@ def test_node_batch():
# test partial # test partial
v = utils.toindex(F.tensor([0, 3, 5, 7, 9])) v = utils.toindex(F.tensor([0, 3, 5, 7, 9]))
n_repr = g.get_n_repr(v) n_repr = g.get_n_repr(v)
nbatch = NodeBatch(g, v, n_repr) nbatch = NodeBatch(v, n_repr)
assert F.allclose(nbatch.data['x'], F.gather_row(feat, F.tensor([0, 3, 5, 7, 9]))) assert F.allclose(nbatch.data['x'], F.gather_row(feat, F.tensor([0, 3, 5, 7, 9])))
assert nbatch.mailbox is None assert nbatch.mailbox is None
assert F.allclose(nbatch.nodes(), F.tensor([0, 3, 5, 7, 9])) assert F.allclose(nbatch.nodes(), F.tensor([0, 3, 5, 7, 9]))
...@@ -39,13 +39,13 @@ def test_edge_batch(): ...@@ -39,13 +39,13 @@ def test_edge_batch():
g.edata['x'] = efeat g.edata['x'] = efeat
# test all # test all
eid = ALL eid = utils.toindex(slice(0, g.number_of_edges()))
u, v, _ = g._graph.edges('eid') u, v, _ = g._graph.edges('eid')
src_data = g.get_n_repr(u) src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid) edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v) dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data) ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == g.number_of_edges() and\ assert F.shape(ebatch.src['x'])[0] == g.number_of_edges() and\
F.shape(ebatch.src['x'])[1] == d F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == g.number_of_edges() and\ assert F.shape(ebatch.dst['x'])[0] == g.number_of_edges() and\
...@@ -64,7 +64,7 @@ def test_edge_batch(): ...@@ -64,7 +64,7 @@ def test_edge_batch():
src_data = g.get_n_repr(u) src_data = g.get_n_repr(u)
edge_data = g.get_e_repr(eid) edge_data = g.get_e_repr(eid)
dst_data = g.get_n_repr(v) dst_data = g.get_n_repr(v)
ebatch = EdgeBatch(g, (u, v, eid), src_data, edge_data, dst_data) ebatch = EdgeBatch((u, v, eid), src_data, edge_data, dst_data)
assert F.shape(ebatch.src['x'])[0] == 8 and\ assert F.shape(ebatch.src['x'])[0] == 8 and\
F.shape(ebatch.src['x'])[1] == d F.shape(ebatch.src['x'])[1] == d
assert F.shape(ebatch.dst['x'])[0] == 8 and\ assert F.shape(ebatch.dst['x'])[0] == 8 and\
......
...@@ -192,7 +192,7 @@ function-naming-style=snake_case ...@@ -192,7 +192,7 @@ function-naming-style=snake_case
#function-rgx= #function-rgx=
# Good variable names which should always be accepted, separated by a comma. # Good variable names which should always be accepted, separated by a comma.
good-names=i,j,k,u,v,e,n,m,w,x,y,g,fn,ex,Run,_ good-names=i,j,k,u,v,e,n,m,w,x,y,g,G,hg,fn,ex,Run,_
# Include a hint for the correct naming format with invalid-name. # Include a hint for the correct naming format with invalid-name.
include-naming-hint=no include-naming-hint=no
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment