Unverified commit 44089c8b, authored by Minjie Wang, committed by GitHub

[Refactor][Graph] Merge DGLGraph and DGLHeteroGraph (#1862)



* Merge

* [Graph][CUDA] Graph on GPU and many refactoring (#1791)

* change edge_ids behavior and C++ impl

* fix unittests; remove utils.Index in edge_id

* pass mx and th tests

* pass tf test

* add aten::Scatter_

* Add nonzero; impl CSRGetDataAndIndices/CSRSliceMatrix

* CSRGetData and CSRGetDataAndIndices passed tests

* CSRSliceMatrix basic tests

* fix bug in empty slice

* CUDA CSRHasDuplicate

* has_node; has_edge_between

* predecessors, successors

* deprecate send/recv; fix send_and_recv

* deprecate send/recv; fix send_and_recv

* in_edges; out_edges; all_edges; apply_edges

* in deg/out deg

* subgraph/edge_subgraph

* adj

* in_subgraph/out_subgraph

* sample neighbors

* set/get_n/e_repr

* wip: working on refactoring all idtypes

* pass ndata/edata tests on gpu

* fix

* stash

* workaround nonzero issue

* stash

* nx conversion

* test_hetero_basics except update routines

* test_update_routines

* test_hetero_basics for pytorch

* more fixes

* WIP: flatten graph

* wip: flatten

* test_flatten

* test_to_device

* fix bug in to_homo

* fix bug in CSRSliceMatrix

* pass subgraph test

* fix send_and_recv

* fix filter

* test_heterograph

* passed all pytorch tests

* fix mx unittest

* fix pytorch test_nn

* fix all unittests for PyTorch

* passed all mxnet tests

* lint

* fix tf nn test

* pass all tf tests

* lint

* lint

* change deprecation

* try fix compile

* lint

* update METIS

* fix utest

* fix

* fix utests

* try debug

* revert

* small fix

* fix utests

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [kernel] Use heterograph index instead of unitgraph index (#1813)

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* trigger

* +1s

* [Graph] Mutation for Heterograph (#1818)

* mutation add_nodes and add_edges

* Add support for remove_edges, remove_nodes, add_selfloop, remove_selfloop

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* upd

* upd

* upd

* fix

* [Transfom] Mutable transform (#1833)

* add nodes

* All three

* Fix

* lint

* Add some test case

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* fix

* trigger

* Fix

* fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>

* [Graph] Migrate Batch & Readout module to heterograph (#1836)

* dgl.batch

* unbatch

* fix to device

* reduce readout; segment reduce

* change batch_num_nodes|edges to function

* reduce readout/ softmax

* broadcast

* topk

* fix

* fix tf and mx

* fix some ci

* fix batch but unbatch differently

* new check

* upd

* upd

* upd

* idtype behavior; code reorg

* idtype behavior; code reorg

* wip: test_basics

* pass test_basics

* WIP: from nx/ to nx

* missing files

* upd

* pass test_basics:test_nx_conversion

* Fix test

* Fix inplace update

* WIP: fixing tests

* upd

* pass test_transform cpu

* pass gpu test_transform

* pass test_batched_graph

* GPU graph auto cast to int32

* missing file

* stash

* WIP: rgcn-hetero

* Fix two datasets

* upd

* weird

* Fix capsule

* fuck you

* fuck matthias

* Fix dgmg

* fix bug in block degrees; pass rgcn-hetero

* rgcn

* gat and diffpool fix
also fix ppi and tu dataset

* Tree LSTM

* pointcloud

* rrn; wip: sgc

* resolve conflicts

* upd

* sgc and reddit dataset

* upd

* Fix deepwalk, gindt and gcn

* fix datasets and sign

* optimization

* optimization

* upd

* upd

* Fix GIN

* fix bug in add_nodes add_edges; tagcn

* adaptive sampling and gcmc

* upd

* upd

* fix geometric

* fix

* metapath2vec

* fix agnn

* fix pickling problem of block

* fix utests

* miss file

* linegraph

* upd

* upd

* upd

* graphsage

* stgcn_wave

* fix hgt

* on unittests

* Fix transformer

* Fix HAN

* passed pytorch unittests

* lint

* fix

* Fix cluster gcn

* cluster-gcn is ready

* on fixing block related codes

* 2nd order derivative

* Revert "2nd order derivative"

This reverts commit 523bf6c249bee61b51b1ad1babf42aad4167f206.

* passed torch utests again

* fix all mxnet unittests

* delete some useless tests

* pass all tf cpu tests

* disable

* disable distributed unittest

* fix

* fix

* lint

* fix

* fix

* fix script

* fix tutorial

* fix apply edges bug

* fix 2 basics

* fix tutorial
Co-authored-by: yzh119 <expye@outlook.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-7-42.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-68-185.ec2.internal>
parent 015acfd2
@@ -149,21 +149,53 @@ inline bool CSRHasData(CSRMatrix csr) {
 /*! \brief Whether the column indices of each row is sorted. */
 bool CSRIsSorted(CSRMatrix csr);
 
-/* \brief Get data. The return type is an ndarray due to possible duplicate entries. */
-runtime::NDArray CSRGetData(CSRMatrix , int64_t row, int64_t col);
 /*!
- * \brief Batched implementation of CSRGetData.
+ * \brief Get the data and the row,col indices for each returned entries.
+ *
+ * The operator supports matrix with duplicate entries and all the matched entries
+ * will be returned. The operator assumes there is NO duplicate (row, col) pair
+ * in the given input. Otherwise, the returned result is undefined.
+ *
+ * If some (row, col) pairs do not contain a valid non-zero elements,
+ * they will not be included in the return arrays.
+ *
  * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ *
+ * \param mat Sparse matrix
+ * \param rows Row index
+ * \param cols Column index
+ * \return Three arrays {rows, cols, data}
  */
-runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
+std::vector<runtime::NDArray> CSRGetDataAndIndices(
+    CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
+
+/* \brief Get data. The return type is an ndarray due to possible duplicate entries. */
+inline runtime::NDArray CSRGetAllData(CSRMatrix mat, int64_t row, int64_t col) {
+  const auto& nbits = mat.indptr->dtype.bits;
+  const auto& ctx = mat.indptr->ctx;
+  IdArray rows = VecToIdArray<int64_t>({row}, nbits, ctx);
+  IdArray cols = VecToIdArray<int64_t>({col}, nbits, ctx);
+  const auto& rst = CSRGetDataAndIndices(mat, rows, cols);
+  return rst[2];
+}
 
 /*!
- * \brief Get the data and the row,col indices for each returned entries.
+ * \brief Get the data for each (row, col) pair.
+ *
+ * The operator supports matrix with duplicate entries but only one matched entry
+ * will be returned for each (row, col) pair. Support duplicate input (row, col)
+ * pairs.
+ *
+ * If some (row, col) pairs do not contain a valid non-zero elements,
+ * their data values are filled with -1.
+ *
  * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ *
+ * \param mat Sparse matrix.
+ * \param rows Row index.
+ * \param cols Column index.
+ * \return Data array. The i^th element is the data of (rows[i], cols[i])
  */
-std::vector<runtime::NDArray> CSRGetDataAndIndices(
-    CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
+runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
 
 /*! \brief Return a transposed CSR matrix */
 CSRMatrix CSRTranspose(CSRMatrix csr);

@@ -218,9 +250,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
  * \brief Get the submatrix specified by the row and col ids.
  *
  * In numpy notation, given matrix M, row index array I, col index array J
- * This function returns the submatrix M[I, J].
+ * This function returns the submatrix M[I, J]. It assumes that there is no
+ * duplicate (row, col) pair in the given indices. M could have duplicate
+ * entries.
  *
- * The sliced row and column IDs are relabeled to starting from zero.
+ * The sliced row and column IDs are relabeled according to the given
+ * rows and cols (i.e., row #0 in the new matrix corresponds to rows[0] in
+ * the original matrix).
  *
  * \param csr The input csr matrix
  * \param rows The row index to select
......
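
For reference, a hedged sketch (illustration only, not code from this commit) of the two lookup semantics documented above, reproduced with NumPy/SciPy: CSRGetData returns exactly one value per queried (row, col) pair and fills -1 for pairs that are not stored, while CSRGetDataAndIndices returns every match and drops missing pairs.

import numpy as np
import scipy.sparse as sp

# a 3x3 CSR matrix whose stored values play the role of DGL edge ids
mat = sp.csr_matrix((np.array([10, 20, 30]),
                     (np.array([0, 0, 1]), np.array([1, 2, 2]))), shape=(3, 3))

def csr_get_data(mat, rows, cols):
    # one match per (row, col) pair; -1 where the pair has no stored entry
    out = np.full(len(rows), -1, dtype=np.int64)
    for i, (r, c) in enumerate(zip(rows, cols)):
        lo, hi = mat.indptr[r], mat.indptr[r + 1]
        hits = np.nonzero(mat.indices[lo:hi] == c)[0]
        if len(hits):
            out[i] = mat.data[lo + hits[0]]
    return out

print(csr_get_data(mat, [0, 1, 2], [1, 2, 0]))  # [10 30 -1]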
@@ -235,13 +235,13 @@
   CHECK_IF((value)->ndim == (_ndim), "ndim", value_name, _ndim)
 
 #define CHECK_SAME_DTYPE(VAR1, VAR2) \
-  CHECK(VAR1->dtype == VAR2->dtype) \
+  CHECK((VAR1)->dtype == (VAR2)->dtype) \
     << "Expected " << (#VAR2) << " to be the same type as " << (#VAR1) << "(" \
     << (VAR1)->dtype << ")" \
     << ". But got " << (VAR2)->dtype << ".";
 
 #define CHECK_SAME_CONTEXT(VAR1, VAR2) \
-  CHECK(VAR1->ctx == VAR2->ctx) \
+  CHECK((VAR1)->ctx == (VAR2)->ctx) \
    << "Expected " << (#VAR2) << " to have the same device context as " << (#VAR1) << "(" \
     << (VAR1)->ctx << ")" \
     << ". But got " << (VAR2)->ctx << ".";
......
@@ -17,13 +17,19 @@ namespace dgl {
  * \brief Sparse format.
  */
 enum class SparseFormat {
-  kAny = 0,
   kCOO = 1,
   kCSR = 2,
   kCSC = 3,
-  kAuto = 4  // kAuto is a placeholder that indicates it would be materialized later.
 };
 
+/*!
+ * \brief Sparse format codes
+ */
+const dgl_format_code_t all_code = 0x7;
+const dgl_format_code_t coo_code = 0x1;
+const dgl_format_code_t csr_code = 0x2;
+const dgl_format_code_t csc_code = 0x4;
+
 // Parse sparse format from string.
 inline SparseFormat ParseSparseFormat(const std::string& name) {
   if (name == "coo")

@@ -32,13 +38,9 @@ inline SparseFormat ParseSparseFormat(const std::string& name) {
     return SparseFormat::kCSR;
   else if (name == "csc")
     return SparseFormat::kCSC;
-  else if (name == "any")
-    return SparseFormat::kAny;
-  else if (name == "auto")
-    return SparseFormat::kAuto;
   else
     LOG(FATAL) << "Sparse format not recognized";
-  return SparseFormat::kAny;
+  return SparseFormat::kCOO;
 }

@@ -47,25 +49,59 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
     return std::string("coo");
   else if (sparse_format == SparseFormat::kCSR)
     return std::string("csr");
-  else if (sparse_format == SparseFormat::kCSC)
-    return std::string("csc");
-  else if (sparse_format == SparseFormat::kAny)
-    return std::string("any");
   else
-    return std::string("auto");
+    return std::string("csc");
 }
 
-inline dgl_format_code_t SparseFormat2Code(SparseFormat sparse_format) {
-  if (sparse_format == SparseFormat::kCOO)
-    return 1;
-  else if (sparse_format == SparseFormat::kCSR)
-    return 2;
-  else if (sparse_format == SparseFormat::kCSC)
-    return 3;
-  else if (sparse_format == SparseFormat::kAny)
-    return 0;
-  else
-    return 4;
+inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
+  std::vector<SparseFormat> ret;
+  if (code & coo_code)
+    ret.push_back(SparseFormat::kCOO);
+  if (code & csr_code)
+    ret.push_back(SparseFormat::kCSR);
+  if (code & csc_code)
+    ret.push_back(SparseFormat::kCSC);
+  return ret;
+}
+
+inline dgl_format_code_t
+SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
+  dgl_format_code_t ret = 0;
+  for (auto format : formats) {
+    switch (format) {
+      case SparseFormat::kCOO:
+        ret |= coo_code;
+        break;
+      case SparseFormat::kCSR:
+        ret |= csr_code;
+        break;
+      case SparseFormat::kCSC:
+        ret |= csc_code;
+        break;
+      default:
+        LOG(FATAL) << "Only support COO/CSR/CSC formats.";
+    }
+  }
+  return ret;
+}
+
+inline std::string CodeToStr(dgl_format_code_t code) {
+  std::string ret = "";
+  if (code & coo_code)
+    ret += "coo ";
+  if (code & csr_code)
+    ret += "csr ";
+  if (code & csc_code)
+    ret += "csc ";
+  return ret;
+}
+
+inline SparseFormat DecodeFormat(dgl_format_code_t code) {
+  if (code & coo_code)
+    return SparseFormat::kCOO;
+  if (code & csc_code)
+    return SparseFormat::kCSC;
+  return SparseFormat::kCSR;
 }
 
 // Sparse matrix object that is exposed to python API.
......
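
The constants above encode a set of sparse formats as a bit mask. A hedged Python sketch of the same scheme (names are illustrative, not DGL API):

COO_CODE, CSR_CODE, CSC_CODE = 0x1, 0x2, 0x4
ALL_CODE = COO_CODE | CSR_CODE | CSC_CODE  # 0x7, mirrors all_code

def formats_to_code(formats):
    # ['coo', 'csc'] -> 0x5, mirrors SparseFormatsToCode
    bits = {'coo': COO_CODE, 'csr': CSR_CODE, 'csc': CSC_CODE}
    code = 0
    for f in formats:
        code |= bits[f]
    return code

def code_to_formats(code):
    # mirrors CodeToSparseFormats
    return [n for n, b in [('coo', COO_CODE), ('csr', CSR_CODE), ('csc', CSC_CODE)]
            if code & b]

assert formats_to_code(['coo', 'csc']) == 0x5
assert code_to_formats(ALL_CODE) == ['coo', 'csr', 'csc']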
@@ -32,6 +32,11 @@ typedef NDArray IntArray;
 typedef NDArray FloatArray;
 typedef NDArray TypeArray;
 
+namespace aten {
+static const DLContext CPU{kDLCPU, 0};
+}  // namespace aten
+
 }  // namespace dgl
 
 #endif  // DGL_ATEN_TYPES_H_
@@ -14,6 +14,8 @@
 #include <memory>
 
 #include "./runtime/object.h"
+#include "aten/spmat.h"
+#include "aten/types.h"
 #include "graph_interface.h"
 #include "array.h"
...@@ -179,17 +181,26 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -179,17 +181,26 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Get all edge ids between the given endpoint pairs. * \brief Get all edge ids between the given endpoint pairs.
* \note Edges are associated with an integer id start from zero. *
* The id is assigned when the edge is being added to the graph. * \param etype The edge type
* If duplicate pairs exist, the returned edge IDs will also duplicate. * \param src The src vertex ids.
* The order of returned edge IDs will follow the order of src-dst pairs * \param dst The dst vertex ids.
* first, and ties are broken by the order of edge ID. * \return EdgeArray containing all edges between all pairs.
*/
virtual EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*!
* \brief Get edge ids between the given endpoint pairs.
*
* Only find one matched edge Ids even if there are multiple matches due to parallel
* edges. The i^th Id in the returned array is for edge (src[i], dst[i]).
*
* \param etype The edge type * \param etype The edge type
* \param src The src vertex ids. * \param src The src vertex ids.
* \param dst The dst vertex ids. * \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs. * \return EdgeArray containing all edges between all pairs.
*/ */
virtual EdgeArray EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const = 0; virtual IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const = 0;
/*! /*!
* \brief Find the edge ID and return the pair of endpoints * \brief Find the edge ID and return the pair of endpoints
...@@ -358,36 +369,37 @@ class BaseHeteroGraph : public runtime::Object { ...@@ -358,36 +369,37 @@ class BaseHeteroGraph : public runtime::Object {
/*! /*!
* \brief Determine which format to use with a preference. * \brief Determine which format to use with a preference.
* *
* Return the preferred format if the underlying relation graph supports it.
* Otherwise, it will return whatever DGL thinks is the most appropriate given * Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments. * the arguments.
* *
* \param etype Edge type. * \param etype Edge type.
* \param preferred_format Preferred sparse format. * \param preferred_formats Preferred sparse formats.
* \return Available sparse format. * \return Available sparse format.
*/ */
virtual SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const = 0; virtual SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const = 0;
/*! /*!
* \brief Get restrict sparse format of the graph. * \brief Return sparse formats already created for the graph.
* *
* \return a string representing the sparse format: 'coo'/'csr'/'csc'/'any' * \return a number of type dgl_format_code_t.
*/ */
virtual std::string GetRestrictFormat() const = 0; virtual dgl_format_code_t GetCreatedFormats() const = 0;
/*! /*!
* \brief Return the sparse format in use for the graph. * \brief Return allowed sparse formats for the graph.
* *
* \return a number of type dgl_format_code_t. * \return a number of type dgl_format_code_t.
*/ */
virtual dgl_format_code_t GetFormatInUse() const = 0; virtual dgl_format_code_t GetAllowedFormats() const = 0;
/*! /*!
* \brief Return the graph in specified restrict format. * \brief Return the graph in specified available formats.
* *
* \return The new graph. * \return The new graph.
*/ */
virtual HeteroGraphPtr GetGraphInFormat(SparseFormat restrict_format) const = 0; virtual HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const = 0;
/*! /*!
* \brief Get adjacency matrix in COO format. * \brief Get adjacency matrix in COO format.
@@ -592,23 +604,23 @@ HeteroGraphPtr CreateHeteroGraph(
  * \param num_dst Number of nodes in the destination type.
  * \param row Src node ids of the edges.
  * \param col Dst node ids of the edges.
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats used for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCOO(
     int64_t num_vtypes, int64_t num_src, int64_t num_dst,
-    IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny);
+    IdArray row, IdArray col, dgl_format_code_t formats = all_code);
 
 /*!
  * \brief Create a heterograph from COO input.
  * \param num_vtypes Number of vertex types. Must be 1 or 2.
  * \param mat The COO matrix
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats used for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCOO(
     int64_t num_vtypes, const aten::COOMatrix& mat,
-    SparseFormat restrict_format = SparseFormat::kAny);
+    dgl_format_code_t formats = all_code);
 
 /*!
  * \brief Create a heterograph from CSR input.

@@ -618,24 +630,24 @@ HeteroGraphPtr CreateFromCOO(
  * \param indptr Indptr array
  * \param indices Indices array
 * \param edge_ids Edge ids
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCSR(
     int64_t num_vtypes, int64_t num_src, int64_t num_dst,
     IdArray indptr, IdArray indices, IdArray edge_ids,
-    SparseFormat restrict_format = SparseFormat::kAny);
+    dgl_format_code_t formats = all_code);
 
 /*!
  * \brief Create a heterograph from CSR input.
  * \param num_vtypes Number of vertex types. Must be 1 or 2.
  * \param mat The CSR matrix
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCSR(
     int64_t num_vtypes, const aten::CSRMatrix& mat,
-    SparseFormat restrict_format = SparseFormat::kAny);
+    dgl_format_code_t formats = all_code);
 
 /*!
 * \brief Create a heterograph from CSC input.

@@ -645,24 +657,24 @@ HeteroGraphPtr CreateFromCSR(
  * \param indptr Indptr array
  * \param indices Indices array
  * \param edge_ids Edge ids
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats used for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCSC(
     int64_t num_vtypes, int64_t num_src, int64_t num_dst,
     IdArray indptr, IdArray indices, IdArray edge_ids,
-    SparseFormat restrict_format = SparseFormat::kAny);
+    dgl_format_code_t formats = all_code);
 
 /*!
  * \brief Create a heterograph from CSC input.
  * \param num_vtypes Number of vertex types. Must be 1 or 2.
 * \param mat The CSC matrix
- * \param restrict_format Sparse format for storing this graph.
+ * \param formats Sparse formats available for storing this graph.
  * \return A heterograph pointer.
  */
 HeteroGraphPtr CreateFromCSC(
     int64_t num_vtypes, const aten::CSRMatrix& mat,
-    SparseFormat restrict_format = SparseFormat::kAny);
+    dgl_format_code_t formats = all_code);
 
 /*!
  * \brief Extract the subgraph of the in edges of the given nodes.
@@ -818,13 +830,13 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
 HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
 
 #define FORMAT_HAS_CSC(format) \
-  (format & (1<<2))
+  ((format) & csc_code)
 
 #define FORMAT_HAS_CSR(format) \
-  (format & (1<<1))
+  ((format) & csr_code)
 
 #define FORMAT_HAS_COO(format) \
-  (format & 1)
+  ((format) & coo_code)
 
 }  // namespace dgl
......
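
At the Python level, the EdgeIdsOne/EdgeIdsAll split above surfaces as the two modes of DGLGraph.edge_ids. A hedged sketch, assuming the 0.5-style dgl.graph constructor and the return_uv keyword:

import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
print(g.edge_ids(0, 1))   # single matched id per pair (EdgeIdsOne underneath)
u, v, eid = g.edge_ids(th.tensor([0, 1]), th.tensor([1, 2]), return_uv=True)
print(eid)                # all matches for the given pairs (EdgeIdsAll)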
@@ -28,15 +28,13 @@ namespace aten {
  * \param out_aux A list of NDArray's that contains auxiliary information such
  *        as the argmax on source nodes and edges for reduce operators such as
  *        `min` and `max`.
- * \param format The format of sparse matrix.
  */
 void SpMM(const std::string& op, const std::string& reduce,
           HeteroGraphPtr graph,
           NDArray ufeat,
           NDArray efeat,
           NDArray out,
-          std::vector<NDArray> out_aux,
-          SparseFormat format = SparseFormat::kAny);
+          std::vector<NDArray> out_aux);
 
 /*!
  * \brief Generalized Sampled Dense-Dense Matrix Multiplication.

@@ -46,14 +44,12 @@ void SpMM(const std::string& op, const std::string& reduce,
  * \param ufeat The source node feature.
 * \param vfeat The destination node feature.
  * \param out The output feature on edge.
- * \param format The format of sparse matrix.
  */
 void SDDMM(const std::string& op,
            HeteroGraphPtr graph,
            NDArray ufeat,
            NDArray efeat,
-           NDArray out,
-           SparseFormat format = SparseFormat::kAny);
+           NDArray out);
 
 }  // namespace aten
 }  // namespace dgl
......
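
With the explicit format argument gone, callers pass only the graph; the kernel selects a sparse format from the heterograph index. A hedged usage sketch through the Python ops wrappers (the dgl.ops names below are assumptions based on the 0.5 API):

import dgl
import torch as th
import dgl.ops as ops

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
x = th.randn(g.num_nodes(), 4)   # node features
w = th.randn(g.num_edges(), 4)   # edge features
h = ops.u_mul_e_sum(g, x, w)     # SpMM: message = u * e, reduce = sum
s = ops.u_dot_v(g, x, x)         # SDDMM: per-edge dot product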
@@ -20,11 +20,12 @@ from ._ffi.base import DGLError, __version__
 from .base import ALL, NTYPE, NID, ETYPE, EID
 from .readout import *
-from .batched_heterograph import *
+from .batch import *
 from .convert import *
-from .graph import DGLGraph, batch, unbatch
+from .graph import DGLGraph as DGLGraphStale
 from .generators import *
 from .heterograph import DGLHeteroGraph
+from .heterograph import DGLHeteroGraph as DGLGraph  # pylint: disable=reimported
 from .nodeflow import *
 from .traversal import *
 from .transform import *
......
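
A consequence of the import block above: dgl.DGLGraph is now the merged heterograph class, and the old class survives only as DGLGraphStale. A hedged sketch:

import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
hg = dgl.heterograph({('user', 'follows', 'user'): (th.tensor([0]), th.tensor([1]))})
assert isinstance(g, dgl.DGLGraph) and isinstance(hg, dgl.DGLGraph)
assert isinstance(g, dgl.DGLHeteroGraph)   # same class under both names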
@@ -279,3 +279,9 @@ def is_enabled(api):
         True if the API is enabled by the current backend.
     """
     return api in _enabled_apis
+
+def to_dgl_nd(data):
+    return zerocopy_to_dgl_ndarray(data)
+
+def from_dgl_nd(data):
+    return zerocopy_from_dgl_ndarray(data)
@@ -59,9 +59,15 @@ def cpu():
 def tensor(data, dtype=None):
     """Create a tensor given the data and data type.
 
+    If the input is already a tensor and has the same dtype,
+    directly return.
+
+    Scalar input is converted to an array of one element instead of
+    a 0-dim tensor to avoid certain issues with some backends.
+
     Parameters
     ----------
-    data : input data
+    data : int, iterable, Tensor
         The interface should at least support list and numpy array.
         The data is copied to a newly-allocated tensor.
     dtype : data type, optional

@@ -610,6 +616,11 @@ def split(input, sizes_or_sections, dim):
     Parameters
     ----------
     input : Tensor
+        Tensor to split.
+    sizes_or_sections : int, list[int]
+        Split sizes or sections.
+    dim : int
+        The dimension to split on.
 
     Returns
     -------

@@ -625,7 +636,7 @@ def repeat(input, repeats, dim):
     ----------
     input : Tensor
         Input data array
-    repeats : int
+    repeats : int, Tensor
         The number of repetitions for each element
     dim : int
         The dim along which to repeat values.

@@ -917,7 +928,7 @@ def uniform(shape, dtype, ctx, low, high):
     pass
 
 def pad_packed_tensor(input, lengths, value, l_min=None):
-    """Pads a packed batch of variable length tensors with given value.
+    r"""Pads a packed batch of variable length tensors with given value.
 
     Parameters
     ----------

@@ -941,7 +952,7 @@ def pad_packed_tensor(input, lengths, value, l_min=None):
     pass
 
 def pack_padded_tensor(input, lengths):
-    """Packs a tensor containing padded sequence of variable length.
+    r"""Packs a tensor containing padded sequence of variable length.
 
     Parameters
     ----------

@@ -1040,7 +1051,7 @@ def equal(x, y):
     Returns
     -------
-    Boolean tensor
+    Boolean or integer tensor
         The result, with the same shape as input.
     """
     pass
......
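
A hedged illustration of the tensor() contract documented above, exercised through the PyTorch backend (the dgl.backend import path is an assumption):

import torch as th
import dgl.backend as F

t = F.tensor(5)              # scalar -> one-element array, not a 0-dim tensor
assert tuple(t.shape) == (1,)
same = F.tensor(t)           # already a tensor with matching dtype
assert same is t             # returned directly, no copy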
@@ -53,12 +53,14 @@ def _reduce_grad(grad, shape):
     If there is broadcast in forward pass, gradients need to be reduced on
     broadcast dimension. This function checks the input tensor shape and
     gradient shape and perform the reduction.
+
     Parameters
     ----------
     grad: Tensor
         Gradient tensor
     shape: tuple
         Shape of input tensor
+
     Returns
     -------
     Tensor

@@ -78,6 +80,14 @@ def _reduce_grad(grad, shape):
         grad = grad.sum(axis=tuple(reduce_idx), keepdims=True)
     return grad.reshape(shape)
 
+def _need_reduce_last_dim(ufeat, efeat):
+    """Indicates whether to reduce the last dimension on edges
+    in the backward pass of spmm,
+    if so, use dot instead of mul."""
+    ushp = ufeat.shape
+    eshp = efeat.shape
+    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
+
 def _muldiv(op, x):
     return 1. / x if op == 'div' else x

@@ -100,8 +110,7 @@ class GSpMM(mx.autograd.Function):
         ctx = context(dZ)
         X, Y, argX, argY = self.saved_tensors
         gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
-        dX, dY = nd.empty((), ctx=ctx), nd.empty((), ctx=ctx)
-        if op != 'copy_rhs' and X.grad is not None:
+        if op != 'copy_rhs':
             g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:

@@ -119,9 +128,13 @@ class GSpMM(mx.autograd.Function):
             elif op in ['add', 'sub', 'copy_lhs']:
                 dX = _scatter_nd(argX, dZ, X.shape[0])
             dX = _reduce_grad(dX, X.shape)
-        if op != 'copy_lhs' and Y.grad is not None:
+        else:
+            dX = nd.zeros_like(X)
+        if op != 'copy_lhs':
             if reduce_op == 'sum':
-                if op in ['mul', 'div']:
+                if op == 'mul' and _need_reduce_last_dim(X, Y):
+                    dY = _gsddmm(gidx, 'dot', X, dZ)
+                elif op in ['mul', 'div']:
                     dY = _gsddmm(gidx, 'mul', X, dZ)
                     if op == 'div': dY = -dY / (Y ** 2)
                 elif op in ['add', 'sub', 'copy_rhs']:

@@ -136,11 +149,17 @@ class GSpMM(mx.autograd.Function):
             elif op in ['add', 'sub', 'copy_rhs']:
                 dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
             dY = _reduce_grad(dY, Y.shape)
+        else:
+            dY = nd.zeros_like(Y)
         self.saved_tensors = None
         return dX, dY
 
 def gspmm(g, op, reduce_op, lhs_data, rhs_data):
     func = GSpMM(g, op, reduce_op)
+    if lhs_data is None:
+        lhs_data = nd.zeros((1,), ctx=g.device)
+    if rhs_data is None:
+        rhs_data = nd.zeros((1,), ctx=g.device)
     return func(lhs_data, rhs_data)
 
 class GSDDMM(mx.autograd.Function):

@@ -161,7 +180,6 @@ class GSDDMM(mx.autograd.Function):
         X, Y = self.saved_tensors
         gidx, op = self.gidx, self.op
         lhs_target, rhs_target = self.lhs_target, self.rhs_target
-        dX, dY = nd.empty((), ctx=ctx), nd.empty((), ctx=ctx)
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if self.lhs_target == 'v' else gidx.reverse()

@@ -180,6 +198,8 @@ class GSDDMM(mx.autograd.Function):
             else:  # mul, div, dot
                 dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
             dX = _reduce_grad(dX, X.shape)
+        else:
+            dX = nd.zeros_like(X)
         if op != 'copy_lhs':
             if self.rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()

@@ -200,9 +220,15 @@ class GSDDMM(mx.autograd.Function):
                 dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
                 if op == 'div': dY = -dY / (Y ** 2)
             dY = _reduce_grad(dY, Y.shape)
+        else:
+            dY = nd.zeros_like(Y)
         self.saved_tensors = None
         return dX, dY
 
 def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
     func = GSDDMM(g, op, lhs_target, rhs_target)
+    if lhs_data is None:
+        lhs_data = nd.zeros((1,), ctx=g.device)
+    if rhs_data is None:
+        rhs_data = nd.zeros((1,), ctx=g.device)
     return func(lhs_data, rhs_data)
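
A NumPy sketch (illustration only, under the shapes assumed here) of what _reduce_grad above does: gradients flowing back to an input that was broadcast in the forward pass must be summed over the broadcast dimensions.

import numpy as np

def reduce_grad(grad, shape):
    # sum over trailing axes where the input had size 1 but the gradient does not
    axes = tuple(i + 1 for i, (g, s) in enumerate(zip(grad.shape[1:], shape[1:]))
                 if s == 1 and g != 1)
    if axes:
        grad = grad.sum(axis=axes, keepdims=True)
    return grad.reshape(-1, *shape[1:])

grad = np.ones((4, 3, 2))                  # gradient w.r.t. the broadcast output
print(reduce_grad(grad, (4, 1, 2)).shape)  # (4, 1, 2): broadcast dim summed away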
@@ -29,26 +29,26 @@ def data_type_dict():
             'int16' : np.int16,
             'int32' : np.int32,
             'int64' : np.int64,
-            'bool' : np.bool}
+            'bool' : np.bool}  # mxnet does not support bool
 
 def cpu():
     return mx.cpu()
 
 def tensor(data, dtype=None):
+    if dtype == np.bool:
+        # mxnet doesn't support bool
+        dtype = np.int32
     if isinstance(data, nd.NDArray):
         if dtype is None or data.dtype == dtype:
             return data
         else:
-            return nd.cast(data, dtype)
+            return data.astype(dtype)
     else:
+        if isinstance(data, numbers.Number):
+            data = [data]
         if dtype is None:
-            if isinstance(data, numbers.Number):
-                dtype = np.int64 if isinstance(data, numbers.Integral) else np.float32
-            elif isinstance(data, np.ndarray):
-                dtype = data.dtype
-                # mxnet doesn't support bool
-                if dtype == np.bool:
-                    dtype = np.int32
+            if isinstance(data, np.ndarray):
+                dtype = np.int32 if data.dtype == np.bool else data.dtype
             else:
                 dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32
         return nd.array(data, dtype=dtype)

@@ -128,7 +128,9 @@ def to_backend_ctx(dglctx):
         raise ValueError('Unsupported DGL device context:', dglctx)
 
 def astype(input, ty):
-    return nd.cast(input, ty)
+    if ty == np.bool:
+        ty = np.int32
+    return input.astype(ty)
 
 def asnumpy(input):
     return input.asnumpy()

@@ -207,7 +209,11 @@ def split(x, sizes_or_sections, dim):
     return nd.split(x, sizes_or_sections, axis=dim)
 
 def repeat(input, repeats, dim):
-    return nd.repeat(input, repeats, axis=dim)
+    if isinstance(repeats, nd.NDArray):
+        return nd.array(np.repeat(input.asnumpy(), repeats.asnumpy(), axis=dim),
+                        ctx=input.context, dtype=input.dtype)
+    else:
+        return nd.repeat(input, repeats, axis=dim)
 
 def gather_row(data, row_index):
     # MXNet workaround for empty row index

@@ -273,9 +279,8 @@ def uniform(shape, dtype, ctx, low, high):
 def pad_packed_tensor(input, lengths, value, l_min=None):
     old_shape = input.shape
     if isinstance(lengths, nd.NDArray):
-        max_len = as_scalar(input.max())
-    else:
-        max_len = builtins.max(lengths)
+        lengths = list(lengths.asnumpy())
+    max_len = builtins.max(lengths)
     if l_min is not None:
         max_len = builtins.max(max_len, l_min)

@@ -356,7 +361,8 @@ def nonzero_1d(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
     tmp = np.nonzero(tmp)[0]
-    return nd.array(tmp, ctx=input.context, dtype=tmp.dtype)
+    r = nd.array(tmp, ctx=input.context, dtype=tmp.dtype)
+    return r
 
 def sort_1d(input):
     # TODO: this isn't an ideal implementation.

@@ -365,11 +371,11 @@ def sort_1d(input):
     idx = nd.cast(idx, dtype='int64')
     return val, idx
 
-def arange(start, stop, dtype="int64"):
+def arange(start, stop, dtype=np.int64):
     if start >= stop:
-        return nd.array([], dtype=data_type_dict()[dtype])
+        return nd.array([], dtype=dtype)
     else:
-        return nd.arange(start, stop, dtype=data_type_dict()[dtype])
+        return nd.arange(start, stop, dtype=dtype)
 
 def rand_shuffle(arr):
     return mx.nd.random.shuffle(arr)

@@ -388,6 +394,7 @@ def zerocopy_from_numpy(np_data):
     return mx.nd.from_numpy(np_data, zero_copy=True)
 
 def zerocopy_to_dgl_ndarray(arr):
+    arr.to_dlpack_for_read()
     return dglnd.from_dlpack(arr.to_dlpack_for_read())
 
 def zerocopy_to_dgl_ndarray_for_write(arr):
......
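
The new repeat() branch above supports per-element repeat counts by falling back to NumPy; the semantics it relies on, sketched directly with NumPy:

import numpy as np

x = np.array([10, 20, 30])
reps = np.array([1, 0, 2])          # one repeat count per element
print(np.repeat(x, reps, axis=0))   # [10 30 30]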
from __future__ import absolute_import

import numpy as np
import scipy.sparse as sp
import warnings

warnings.warn('Detect using numpy backend. Please be aware that numpy does not support autograd!')

def data_type_dict():
    return {'float16' : np.float16,
            'float32' : np.float32,
            'float64' : np.float64,
            'uint8'   : np.uint8,
            'int8'    : np.int8,
            'int16'   : np.int16,
            'int32'   : np.int32,
            'int64'   : np.int64}

def cpu():
    return 'cpu'

def tensor(data, dtype=None):
    return np.array(data, dtype)

def as_scalar(data):
    if data.ndim > 1:
        raise ValueError('The data must have shape (1,).')
    return data[0]

def get_preferred_sparse_format():
    """Get the preferred sparse matrix format supported by the backend.

    Different backends have their preferred backend. This info is useful when
    constructing a sparse matrix.
    """
    return "csr"

def sparse_matrix(data, index, shape, force_format=False):
    fmt = index[0]
    if fmt == 'coo':
        i = index[1][0, :]
        j = index[1][1, :]
        return sp.coo_matrix((data, (i, j)), shape=shape)
    elif fmt == 'csr':
        indices = index[1]
        indptr = index[2]
        return sp.csr_matrix((data, indices, indptr), shape=shape)
    else:
        raise TypeError('Invalid format: %s.' % fmt)

def sparse_matrix_indices(spmat):
    if spmat.format == 'coo':
        return ('coo', np.stack([spmat.row, spmat.col]))
    elif spmat.format == 'csr':
        return ('csr', spmat.indices, spmat.indptr)
    else:
        raise TypeError('Invalid format: %s.' % spmat.format)

def is_tensor(obj):
    return isinstance(obj, np.ndarray)

def shape(input):
    return input.shape

def dtype(input):
    return input.dtype

def context(input):
    return 'cpu'

def astype(input, ty):
    return input.astype(ty)

def asnumpy(input):
    return input

def copy_to(input, ctx):
    return input

def sum(input, dim):
    return np.sum(input, axis=dim)

def reduce_sum(input):
    dtype = input.dtype
    return np.array(input.sum(), dtype=dtype)

def mean(input, dim):
    return np.mean(input, axis=dim)

def reduce_mean(input):
    dtype = input.dtype
    return np.array(input.mean(), dtype=dtype)

def max(input, dim):
    return np.max(input, axis=dim)

def reduce_max(input):
    dtype = input.dtype
    return np.array(input.max(), dtype=dtype)

def min(input, dim):
    return np.min(input, axis=dim)

def reduce_min(input):
    dtype = input.dtype
    return np.array(input.min(), dtype=dtype)

def argsort(input, dim, descending):
    if descending:
        return np.argsort(-input, axis=dim)
    return np.argsort(input, axis=dim)

def topk(input, k, dim, descending=True):
    topk_indices = argtopk(input, k, dim, descending)
    return np.take_along_axis(input, topk_indices, axis=dim)

def argtopk(input, k, dim, descending=True):
    sort_indices = argsort(input, dim, descending)
    return slice_axis(sort_indices, dim, 0, k)

def exp(input):
    return np.exp(input)

def softmax(input, dim=-1):
    max_val = input.max(axis=dim)
    minus_max = input - np.expand_dims(max_val, axis=dim)
    exp_val = np.exp(minus_max)
    sum_val = np.sum(exp_val, axis=dim)
    return exp_val / np.expand_dims(sum_val, axis=dim)

def cat(seq, dim):
    return np.concatenate(seq, axis=dim)

def split(input, sizes_or_sections, dim):
    dimsize = input.shape[dim]
    if isinstance(sizes_or_sections, int):
        if dimsize % sizes_or_sections != 0:
            raise ValueError('Require dimension %d to be equally splitted'
                             ' to %d pieces, but got %d.' % (dim, sizes_or_sections, dimsize))
        idx = np.arange(sizes_or_sections, dimsize, sizes_or_sections)
    else:
        idx = np.cumsum(sizes_or_sections)[0:-1]
    return np.split(input, idx, axis=dim)

def repeat(input, repeats, dim):
    return np.repeat(input, repeats, axis=dim)

def gather_row(data, row_index):
    return data[row_index]

def slice_axis(data, axis, begin, end):
    if begin >= end:
        raise IndexError("Begin index ({}) equals or greater than end index ({})".format(begin, end))
    return np.take(data, np.arange(begin, end), axis=axis)

def take(data, indices, dim):
    return np.take(data, indices, axis=dim)

def scatter_row(data, row_index, value):
    # NOTE: inplace instead of out-place
    data[row_index] = value
    return data

def scatter_row_inplace(data, row_index, value):
    data[row_index] = value

def squeeze(input, dim):
    return np.squeeze(input, dim)

def unsqueeze(input, dim):
    return np.expand_dims(input, dim)

def reshape(input, shape):
    return np.reshape(input, shape)

def zeros(shape, dtype):
    return np.zeros(shape, dtype=dtype)

def ones(shape, dtype):
    return np.ones(shape, dtype=dtype)

def spmm(x, y):
    return x.dot(y)

def unique(input):
    return np.unique(input)

def full_1d(length, fill_value):
    return np.full((length,), fill_value)

def nonzero_1d(input):
    return np.nonzero(input)[0]

def sort_1d(input):
    return np.sort(input), np.argsort(input)

def arange(start, stop, dtype="int64"):
    return np.arange(start, stop, dtype=getattr(np, dtype))

def rand_shuffle(arr):
    copy = np.copy(arr)
    np.random.shuffle(copy)
    return copy

# zerocopy_to_dlpack not enabled
# zerocopy_from_dlpack not enabled

def zerocopy_to_numpy(input):
    return input

def zerocopy_from_numpy(np_array):
    return np_array
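
A quick check (not part of this commit) of the max-subtraction trick used in softmax() above: shifting by the row max leaves the result unchanged but keeps np.exp in range.

import numpy as np

x = np.array([[1000., 1001., 1002.]])        # naive np.exp(x) overflows
m = x - x.max(axis=-1, keepdims=True)
p = np.exp(m) / np.exp(m).sum(axis=-1, keepdims=True)
print(p)  # approximately [[0.090 0.245 0.665]]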
@@ -8,13 +8,14 @@ def _reduce_grad(grad, shape):
     If there is broadcast in forward pass, gradients need to be reduced on
     broadcast dimension. This function checks the input tensor shape and
     gradient shape and perform the reduction.
+
     Parameters
     ----------
     grad: Tensor
         Gradient tensor
     shape: tuple
-        Shape of input tens
-        or
+        Shape of input tensor
+
     Returns
     -------
     Tensor

@@ -33,6 +34,14 @@ def _reduce_grad(grad, shape):
         grad = grad.sum(dim=tuple(reduce_idx), keepdim=True)
     return grad.view(-1, *shape[1:])
 
+def _need_reduce_last_dim(ufeat, efeat):
+    """Indicates whether to reduce the last dimension on edges
+    in the backward pass of spmm,
+    if so, use dot instead of mul."""
+    ushp = ufeat.shape
+    eshp = efeat.shape
+    return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
+
 def _muldiv(op, x):
     return 1. / x if op == 'div' else x

@@ -52,7 +61,6 @@ class GSpMM(th.autograd.Function):
     def backward(ctx, dZ):
         gidx, op, reduce_op = ctx.backward_cache
         X, Y, argX, argY = ctx.saved_tensors
-        dX, dY = None, None
         if op != 'copy_rhs' and ctx.needs_input_grad[3]:
             g_rev = gidx.reverse()
             if reduce_op == 'sum':

@@ -70,9 +78,13 @@ class GSpMM(th.autograd.Function):
             elif op in ['add', 'sub', 'copy_lhs']:
                 dX.scatter_add_(0, argX.long(), dZ)
             dX = _reduce_grad(dX, X.shape)
+        else:
+            dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[4]:
             if reduce_op == 'sum':
-                if op in ['mul', 'div']:
+                if op == 'mul' and _need_reduce_last_dim(X, Y):
+                    dY = _gsddmm(gidx, 'dot', X, dZ)
+                elif op in ['mul', 'div']:
                     dY = _gsddmm(gidx, 'mul', X, dZ)
                     if op == 'div': dY = -dY / (Y ** 2)
                 elif op in ['add', 'sub', 'copy_rhs']:

@@ -86,6 +98,8 @@ class GSpMM(th.autograd.Function):
             elif op in ['add', 'sub', 'copy_rhs']:
                 dY.scatter_add_(0, argY.long(), _addsub(op, dZ))
             dY = _reduce_grad(dY, Y.shape)
+        else:
+            dY = None
         return None, None, None, dX, dY
 
 class GSDDMM(th.autograd.Function):

@@ -101,7 +115,6 @@ class GSDDMM(th.autograd.Function):
     def backward(ctx, dZ):
         gidx, op, lhs_target, rhs_target = ctx.backward_cache
         X, Y = ctx.saved_tensors
-        dX, dY = None, None
         if op != 'copy_rhs' and ctx.needs_input_grad[2]:
             if lhs_target in ['u', 'v']:
                 _gidx = gidx if lhs_target == 'v' else gidx.reverse()

@@ -120,6 +133,8 @@ class GSDDMM(th.autograd.Function):
             else:  # mul, div, dot
                 dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
             dX = _reduce_grad(dX, X.shape)
+        else:
+            dX = None
         if op != 'copy_lhs' and ctx.needs_input_grad[3]:
             if rhs_target in ['u', 'v']:
                 _gidx = gidx if rhs_target == 'v' else gidx.reverse()

@@ -140,6 +155,8 @@ class GSDDMM(th.autograd.Function):
                 dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
                 if op == 'div': dY = -dY / (Y ** 2)
             dY = _reduce_grad(dY, Y.shape)
+        else:
+            dY = None
         return None, None, dX, dY, None, None
 
 def gspmm(g, op, reduce_op, lhs_data, rhs_data):
......
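
The shape case that _need_reduce_last_dim guards, spelled out as a sketch: node features (n, d) multiplied by scalar edge weights (m, 1) broadcast in the forward pass, so the edge-weight gradient must be a per-edge dot product rather than an elementwise mul.

import torch as th

ufeat = th.randn(5, 8)    # node features, last dim > 1
efeat = th.randn(7, 1)    # one scalar weight per edge
ushp, eshp = ufeat.shape, efeat.shape
use_dot = ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
assert use_dot            # backward then computes _gsddmm(gidx, 'dot', X, dZ)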
@@ -5,6 +5,7 @@ from distutils.version import LooseVersion
 import scipy  # Weird bug in new pytorch when import scipy after import torch
 import torch as th
 import builtins
+import numbers
 from torch.utils import dlpack
 
 from ... import ndarray as nd

@@ -31,7 +32,12 @@ def cpu():
     return th.device('cpu')
 
 def tensor(data, dtype=None):
-    return th.as_tensor(data, dtype=dtype)
+    if isinstance(data, numbers.Number):
+        data = [data]
+    if isinstance(data, th.Tensor):
+        return th.as_tensor(data, dtype=dtype, device=data.device)
+    else:
+        return th.as_tensor(data, dtype=dtype)
 
 def as_scalar(data):
     return data.item()

@@ -98,6 +104,7 @@ def asnumpy(input):
     return input.cpu().detach().numpy()
 
 def copy_to(input, ctx, **kwargs):
+    ctx = th.device(ctx)
     if ctx.type == 'cpu':
         return input.cpu()
     elif ctx.type == 'cuda':

@@ -161,10 +168,7 @@ def split(input, sizes_or_sections, dim):
     return th.split(input, sizes_or_sections, dim)
 
 def repeat(input, repeats, dim):
-    # return th.repeat_interleave(input, repeats, dim)  # PyTorch 1.1
-    if dim < 0:
-        dim += input.dim()
-    return th.flatten(th.stack([input] * repeats, dim=dim+1), dim, dim+1)
+    return th.repeat_interleave(input, repeats, dim)  # PyTorch 1.1
 
 def gather_row(data, row_index):
     return th.index_select(data, 0, row_index.long())

@@ -186,7 +190,7 @@ def scatter_row(data, row_index, value):
     return data.index_copy(0, row_index.long(), value)
 
 def scatter_row_inplace(data, row_index, value):
-    data[row_index] = value
+    data[row_index.long()] = value
 
 def squeeze(input, dim):
     return th.squeeze(input, dim)

@@ -286,8 +290,8 @@ def nonzero_1d(input):
 def sort_1d(input):
     return th.sort(input)
 
-def arange(start, stop, dtype="int64"):
-    return th.arange(start, stop, dtype=data_type_dict()[dtype])
+def arange(start, stop, dtype=th.int64):
+    return th.arange(start, stop, dtype=dtype)
 
 def rand_shuffle(arr):
     idx = th.randperm(len(arr))

@@ -306,15 +310,21 @@ def zerocopy_to_numpy(input):
 def zerocopy_from_numpy(np_array):
     return th.as_tensor(np_array)
 
-def zerocopy_to_dgl_ndarray(input):
-    return nd.from_dlpack(dlpack.to_dlpack(input.contiguous()))
+def zerocopy_to_dgl_ndarray(data):
+    return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
 
 def zerocopy_to_dgl_ndarray_for_write(input):
     return zerocopy_to_dgl_ndarray(input)
 
-def zerocopy_from_dgl_ndarray(input):
-    return dlpack.from_dlpack(input.to_dlpack())
+def zerocopy_from_dgl_ndarray(data):
+    if data.shape == (0,):
+        # NOTE: PyTorch v1.5 does not accept DLPack object representing empty CUDA tensor.
+        # Related issue: https://github.com/pytorch/pytorch/issues/41182
+        # The issue will be fixed in v1.6 and later.
+        return th.tensor([], dtype=getattr(th, data.dtype),
+                         device=to_backend_ctx(data.ctx))
+    else:
+        return dlpack.from_dlpack(data.to_dlpack())
 
 class BinaryReduce(th.autograd.Function):
......
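
A hedged round-trip sketch for the zerocopy helpers above (assuming they are re-exported through dgl.backend): the DLPack conversion shares memory, so writes through one view are visible through the other.

import torch as th
import dgl.backend as F

t = th.arange(4, dtype=th.float32)
nd_arr = F.zerocopy_to_dgl_ndarray(t)
t2 = F.zerocopy_from_dgl_ndarray(nd_arr)
t2[0] = 42.0
assert t[0].item() == 42.0   # same underlying buffer, no copy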
...@@ -71,6 +71,14 @@ def _reduce_grad(grad, shape): ...@@ -71,6 +71,14 @@ def _reduce_grad(grad, shape):
grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True) grad = tf.reduce_sum(grad, axis=reduce_idx_tensor, keepdims=True)
return tf.reshape(grad, shape) return tf.reshape(grad, shape)
def _need_reduce_last_dim(ufeat, efeat):
"""Indicates whether to reduce the last dimension on edges
in the backward pass of spmm,
if so, use dot instead of mul."""
ushp = ufeat.shape
eshp = efeat.shape
return ushp[1:-1] == eshp[1:-1] and eshp[-1] == 1 and ushp[-1] > 1
def _muldiv(op, x): def _muldiv(op, x):
return 1. / x if op == 'div' else x return 1. / x if op == 'div' else x
...@@ -83,7 +91,6 @@ def gspmm_real(g, op, reduce_op, X, Y): ...@@ -83,7 +91,6 @@ def gspmm_real(g, op, reduce_op, X, Y):
def grad(dZ): def grad(dZ):
dZ = tensor(dZ) dZ = tensor(dZ)
dX, dY = tf.zeros(()), tf.zeros(())
if op != 'copy_rhs': if op != 'copy_rhs':
g_rev = gidx.reverse() g_rev = gidx.reverse()
if reduce_op == 'sum': if reduce_op == 'sum':
...@@ -102,9 +109,13 @@ def gspmm_real(g, op, reduce_op, X, Y): ...@@ -102,9 +109,13 @@ def gspmm_real(g, op, reduce_op, X, Y):
elif op in ['add', 'sub', 'copy_lhs']: elif op in ['add', 'sub', 'copy_lhs']:
dX = _scatter_nd(argX, dZ, X.shape[0]) dX = _scatter_nd(argX, dZ, X.shape[0])
dX = _reduce_grad(dX, X.shape) dX = _reduce_grad(dX, X.shape)
else:
dX = tf.zeros_like(X)
if op != 'copy_lhs': if op != 'copy_lhs':
if reduce_op == 'sum': if reduce_op == 'sum':
if op in ['mul', 'div']: if op == 'mul' and _need_reduce_last_dim(X, Y):
dY = _gsddmm(gidx, 'dot', X, dZ)
elif op in ['mul', 'div']:
dY = _gsddmm(gidx, 'mul', X, dZ) dY = _gsddmm(gidx, 'mul', X, dZ)
if op == 'div': dY = -dY / (Y ** 2) if op == 'div': dY = -dY / (Y ** 2)
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
...@@ -120,6 +131,8 @@ def gspmm_real(g, op, reduce_op, X, Y): ...@@ -120,6 +131,8 @@ def gspmm_real(g, op, reduce_op, X, Y):
elif op in ['add', 'sub', 'copy_rhs']: elif op in ['add', 'sub', 'copy_rhs']:
dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0]) dY = _scatter_nd(argY, _addsub(op, dZ), Y.shape[0])
dY = _reduce_grad(dY, Y.shape) dY = _reduce_grad(dY, Y.shape)
else:
dY = tf.zeros_like(Y)
return dX, dY return dX, dY
return out, grad return out, grad
...@@ -127,6 +140,10 @@ def gspmm(g, op, reduce_op, X, Y): ...@@ -127,6 +140,10 @@ def gspmm(g, op, reduce_op, X, Y):
@tf.custom_gradient @tf.custom_gradient
def _lambda(X, Y): def _lambda(X, Y):
return gspmm_real(g, op, reduce_op, X, Y) return gspmm_real(g, op, reduce_op, X, Y)
if X is None:
X = tf.zeros(())
if Y is None:
Y = tf.zeros(())
return _lambda(X, Y) return _lambda(X, Y)
def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
@@ -134,7 +151,6 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
    out = _gsddmm(gidx, op, X, Y, lhs_target, rhs_target)
    def grad(dZ):
        if op != 'copy_rhs':
            if lhs_target in ['u', 'v']:
                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
@@ -153,6 +169,8 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
            else:  # mul, div, dot
                dX = _gsddmm(gidx, 'mul', dZ, _muldiv(op, Y), 'e', rhs_target)
            dX = _reduce_grad(dX, X.shape)
        else:
            dX = tf.zeros_like(X)
        if op != 'copy_lhs':
            if rhs_target in ['u', 'v']:
                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
@@ -173,6 +191,8 @@ def gsddmm_real(g, op, X, Y, lhs_target, rhs_target):
                dY = _gsddmm(gidx, 'mul', dZ, X, 'e', lhs_target)
                if op == 'div': dY = -dY / (Y ** 2)
            dY = _reduce_grad(dY, Y.shape)
        else:
            dY = tf.zeros_like(Y)
        return dX, dY
    return out, grad
@@ -180,4 +200,8 @@ def gsddmm(g, op, X, Y, lhs_target='u', rhs_target='v'):
    @tf.custom_gradient
    def _lambda(X, Y):
        return gsddmm_real(g, op, X, Y, lhs_target, rhs_target)
    if X is None:
        X = tf.zeros(())
    if Y is None:
        Y = tf.zeros(())
    return _lambda(X, Y)
@@ -6,6 +6,7 @@ from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow.python.eager import context
import builtins
import numbers
import numpy as np
import os
@@ -17,8 +18,8 @@ if not os.getenv("USE_TFDLPACK", False):
    if LooseVersion(tf.__version__) < LooseVersion("2.2.0"):
        raise RuntimeError("DGL requires tensorflow>=2.2.0 for the official DLPack support.")
def zerocopy_to_dlpack(data):
    return tf.experimental.dlpack.to_dlpack(data)
def zerocopy_from_dlpack(dlpack_tensor):
    # TODO(Jinjing): Tensorflow requires memory to be 64-bytes aligned. We check the
@@ -57,15 +58,22 @@ def cpu():
    return "/cpu:0"
def tensor(data, dtype=None):
    if isinstance(data, tf.Tensor):
        if dtype is None or data.dtype == dtype:
            return data
        else:
            return tf.cast(data, dtype=dtype)
    else:
        if isinstance(data, numbers.Number):
            data = [data]
        return tf.convert_to_tensor(data, dtype=dtype)
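Hypothetical usage of the rewritten ``tensor``: an existing ``tf.Tensor`` with a matching dtype passes through without a copy, a mismatched dtype is cast, and a bare Python number is promoted to a length-1 tensor, mirroring the other DGL backends:

```python
import tensorflow as tf

t = tf.constant([1, 2], dtype=tf.int64)
assert tensor(t) is t                               # no copy on matching dtype
assert tensor(t, dtype=tf.int32).dtype == tf.int32  # explicit cast
assert tensor(5).shape == (1,)                      # scalar -> [5]
```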
def initialize_context():
    tf.zeros(1)
def as_scalar(data):
    data = data.numpy()
    return data if np.isscalar(data) else data.item()
def get_preferred_sparse_format():
    """Get the preferred sparse matrix format supported by the backend.
@@ -109,8 +117,8 @@ def ndim(input):
def context(input):
    spec = tf.DeviceSpec.from_string(input.device)
    return "/{}:{}".format(spec.device_type.lower(), spec.device_index)
def device_type(ctx):
    return tf.DeviceSpec.from_string(ctx).device_type.lower()
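The normalization matters because eager tensors report fully qualified device strings; routing them through ``tf.DeviceSpec`` keeps the values returned by ``context`` comparable. A short illustration with an assumed device string:

```python
import tensorflow as tf

spec = tf.DeviceSpec.from_string("/job:localhost/replica:0/task:0/device:GPU:0")
print("/{}:{}".format(spec.device_type.lower(), spec.device_index))  # /gpu:0
```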
@@ -147,10 +155,14 @@ def copy_to(input, ctx, **kwargs):
def sum(input, dim, keepdims=False):
    if input.dtype == tf.bool:
        input = tf.cast(input, tf.int32)
    return tf.reduce_sum(input, axis=dim, keepdims=keepdims)
def reduce_sum(input):
    if input.dtype == tf.bool:
        input = tf.cast(input, tf.int32)
    return tf.reduce_sum(input)
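The new guards exist because ``tf.reduce_sum`` rejects ``tf.bool`` inputs (unlike ``torch.sum``), and DGL frequently sums boolean masks. Minimal repro of the pattern:

```python
import tensorflow as tf

mask = tf.constant([True, False, True])
count = tf.reduce_sum(tf.cast(mask, tf.int32))  # -> 2; summing the raw mask raises
```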
@@ -234,7 +246,7 @@ def split(input, sizes_or_sections, dim):
def repeat(input, repeats, dim):
    return tf.repeat(input, repeats, dim)
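``tf.repeat`` (available in TF >= 2.1, below the 2.2 floor required here) matches the semantics of the Keras backend helper it replaces:

```python
import tensorflow as tf

x = tf.constant([[1, 2], [3, 4]])
tf.repeat(x, 2, axis=0)  # each row duplicated: [[1, 2], [1, 2], [3, 4], [3, 4]]
```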
def gather_row(data, row_index):
@@ -312,7 +324,7 @@ def uniform(shape, dtype, ctx, low, high):
def pad_packed_tensor(input, lengths, value, l_min=None):
    old_shape = input.shape
    if isinstance(lengths, tf.Tensor):
        max_len = as_scalar(tf.reduce_max(lengths))
    else:
        max_len = builtins.max(lengths)
@@ -393,9 +405,9 @@ def sort_1d(input):
    return tf.sort(input), tf.cast(tf.argsort(input), dtype=tf.int64)
def arange(start, stop, dtype=tf.int64):
    with tf.device("/cpu:0"):
        t = tf.range(start, stop, dtype=dtype)
    return t
@@ -415,8 +427,14 @@ def zerocopy_from_numpy(np_array):
    return t
def zerocopy_to_dgl_ndarray(data):
    if data.dtype == tf.int32 and device_type(data.device) == 'gpu':
        # NOTE: TF doesn't keep int32 tensors on GPU due to legacy issues with
        # shape inference. Convert it to uint32 and cast it back afterwards.
        data = tf.cast(data, tf.uint32)
        return nd.cast_to_signed(nd.from_dlpack(zerocopy_to_dlpack(data)))
    else:
        return nd.from_dlpack(zerocopy_to_dlpack(data))
def zerocopy_to_dgl_ndarray_for_write(input):
    return zerocopy_to_dgl_ndarray(input)
@@ -437,56 +455,55 @@ def binary_reduce(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
                       out_size, lhs_map, rhs_map, out_map):
    with tf.device(lhs_data.device):
        lhs_data_nd = zerocopy_to_dgl_ndarray(lhs_data)
        rhs_data_nd = zerocopy_to_dgl_ndarray(rhs_data)
        feat_shape = K.infer_binary_feature_shape(
            binary_op, lhs_data_nd, rhs_data_nd)
        out_shape = feat_shape
        if binary_op == 'dot':
            out_shape = feat_shape[:-1]
        out_data = tf.zeros((out_size,) + out_shape, dtype=lhs_data.dtype)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        K.binary_op_reduce(
            reducer if reducer != 'mean' else 'sum',
            binary_op, graph, lhs, rhs, lhs_data_nd, rhs_data_nd,
            out_data_nd, lhs_map[0], rhs_map[0], out_map[0])
        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
        if reducer == 'mean':
            degs = tf.zeros((out_data.shape[0],), dtype=lhs_data.dtype)
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            if lhs != TargetCode.DST:  # src or edge
                target = lhs
                n = lhs_data.shape[0]
                in_map = lhs_map[0]
            else:  # rhs != TargetCode.DST
                target = rhs
                n = rhs_data.shape[0]
                in_map = rhs_map[0]
            in_ones = tf.ones((n,), dtype=lhs_data.dtype)
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            K.copy_reduce(
                'sum', graph, target, in_ones_nd, degs_nd, in_map, out_map[0])
            # reshape
            degs = tf.reshape(degs,
                              (out_data.shape[0],) + (1,) * (out_data.ndim - 1))
            degs = tf.clip_by_value(degs, clip_value_min=1,
                                    clip_value_max=np.inf)  # avoid 0/0 for outputs with no messages
            out_data = out_data / degs
        else:
            degs = None
    def grad(grad_out):
        with tf.device(grad_out.device):
            grad_lhs = None
            grad_rhs = None
            if reducer == 'mean':
                grad_out = grad_out / degs
            grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
            # compute gradient for lhs
            grad_lhs = tf.zeros((lhs_data_nd.shape[0],) + feat_shape)
            K.backward_lhs_binary_op_reduce(
                reducer if reducer != 'mean' else 'sum',
@@ -494,8 +511,8 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
                out_data_nd, grad_out_nd, zerocopy_to_dgl_ndarray(grad_lhs),
                lhs_map[1], rhs_map[1], out_map[1])
            grad_lhs = _reduce_grad(grad_lhs, lhs_data_nd.shape)
            # compute gradient for rhs
            grad_rhs = tf.zeros((rhs_data_nd.shape[0],) + feat_shape)
            K.backward_rhs_binary_op_reduce(
                reducer if reducer != 'mean' else 'sum',
@@ -504,7 +521,7 @@ def binary_reduce_real(reducer, binary_op, graph, lhs, rhs, lhs_data, rhs_data,
                lhs_map[1], rhs_map[1], out_map[1])
            grad_rhs = _reduce_grad(grad_rhs, rhs_data_nd.shape)
            return grad_lhs, grad_rhs
    return out_data, grad
@@ -519,50 +536,45 @@ def copy_reduce(reducer, graph, target, in_data, out_size, in_map=(None, None),
def copy_reduce_real(reducer, graph, target, in_data, out_size, in_map,
                     out_map):
    with tf.device(in_data.device):
        out_data = tf.zeros(
            (out_size,) + tuple(in_data.shape[1:]), dtype=in_data.dtype)
        in_data_nd = zerocopy_to_dgl_ndarray(in_data)
        out_data_nd = zerocopy_to_dgl_ndarray(out_data)
        K.copy_reduce(
            reducer if reducer != 'mean' else 'sum',
            graph, target, in_data_nd, out_data_nd, in_map[0], out_map[0])
        # normalize if mean reducer
        # NOTE(zihao): this is a temporary hack and we should have a better solution in the future.
        if reducer == 'mean':
            in_ones = tf.ones(in_data.shape[0], dtype=in_data.dtype)
            degs = tf.zeros(out_data.shape[0], dtype=in_data.dtype)
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            K.copy_reduce(
                'sum', graph, target, in_ones_nd, degs_nd, in_map[0], out_map[0])
            # reshape
            degs = tf.reshape(degs,
                              (out_data.shape[0],) + (1,) * (out_data.ndim - 1))
            degs = tf.clip_by_value(degs, clip_value_min=1,
                                    clip_value_max=np.inf)  # avoid 0/0 for outputs with no messages
            out_data = out_data / degs
        else:
            degs = None
    def grad(grad_out):
        with tf.device(grad_out.device):
            if reducer == 'mean':
                grad_out = grad_out / degs
            grad_out_nd = zerocopy_to_dgl_ndarray(grad_out)
            grad_in = tf.zeros(in_data_nd.shape)
            K.backward_copy_reduce(
                reducer if reducer != 'mean' else 'sum',
                graph, target, in_data_nd, out_data_nd, grad_out_nd,
                zerocopy_to_dgl_ndarray(grad_in), in_map[1], out_map[1])
            return grad_in
    return out_data, grad
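Both functions implement ``mean`` as a sum followed by division by per-output degrees, themselves obtained by copy-reducing a vector of ones; the ``clip_by_value(..., 1, inf)`` protects outputs that received no messages from 0/0. A toy NumPy sketch of just the normalization step:

```python
import numpy as np

out_sum = np.array([6., 0., 4.])             # summed messages per output node
degs = np.array([3., 0., 2.])                # messages each output received
out_mean = out_sum / np.clip(degs, 1, None)  # -> [2., 0., 2.]; node 1 avoids 0/0
```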
def _reduce_grad(grad, shape):
    """Reduce gradient on the broadcast dimension
@@ -601,7 +613,6 @@ def sync():
    context = context().context()
    context.async_wait()
class GradContext:
    def __init__(self):
        self.tensor_for_grad = []
...
@@ -26,6 +26,12 @@ def is_all(arg):
    """Return true if the argument is a special symbol for all nodes or edges."""
    return isinstance(arg, str) and arg == ALL
# pylint: disable=unused-argument
def dgl_warning_format(message, category, filename, lineno, file=None, line=None):
    """Format DGL warnings."""
    return "DGL Warning: {}\n".format(message)
warnings.formatwarning = dgl_warning_format
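With the formatter installed, anything routed through ``dgl_warning`` renders as a single prefixed line instead of Python's default file-and-line dump, e.g. (hypothetical message):

```python
import warnings

warnings.warn("the graph is on GPU; casting idtype to int32")
# stderr: DGL Warning: the graph is on GPU; casting idtype to int32
```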
dgl_warning = warnings.warn  # pylint: disable=invalid-name
_init_internal_api()
"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning
from . import convert
from . import utils
__all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero']
def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
r"""Batch a collection of ``DGLGraph``s into one graph for more efficient
graph computation.
Each input graph becomes one disjoint component of the batched graph. The nodes
and edges are relabeled to be disjoint segments:
================  =========  =================  ===  =========================
                  graphs[0]  graphs[1]          ...  graphs[k]
================  =========  =================  ===  =========================
Original node ID  0 ~ N_0    0 ~ N_1            ...  0 ~ N_k
New node ID       0 ~ N_0    N_0+1 ~ N_0+N_1+1  ...  1+\sum_{i=0}^{k-1} N_i ~
                                                     1+\sum_{i=0}^k N_i
================  =========  =================  ===  =========================
Because of this, many computations on a batched graph give the same results as if
they were performed on each graph individually, but run much more efficiently
since they can be easily parallelized. This makes ``dgl.batch`` very useful
for workloads with many graph samples, such as graph classification.
Heterograph inputs must share the same set of relations (i.e., node types
and edge types); the function batches each relation one by one.
The result is thus also a heterograph with the same set of relations as the inputs.
The numbers of nodes and edges of the input graphs are accessible via the
:func:`DGLGraph.batch_num_nodes` and :func:`DGLGraph.batch_num_edges` methods
of the result graph. For homographs, they return 1D integer tensors, with each
element being the number of nodes/edges of the corresponding input graph. For
heterographs, they return dictionaries of 1D integer tensors, keyed by node
type or edge type.
The function supports batching batched graphs. The batch size of the result
graph is the sum of the batch sizes of all the input graphs.
By default, node/edge features are batched by concatenating the feature tensors
of all input graphs, which requires features of the same name to have
the same data type and feature size. One can pass ``None`` to the ``ndata``
or ``edata`` argument to prevent feature batching, or pass a list of strings
to specify which features to batch.
To unbatch the graph back to a list, use the :func:`dgl.unbatch` function.
Parameters
----------
graphs : list[DGLGraph]
Input graphs.
ndata : list[str], None, optional
Node features to batch.
edata : list[str], None, optional
Edge features to batch.
Returns
-------
DGLGraph
Batched graph.
Examples
--------
Batch homographs
>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg
Graph(num_nodes=7, num_edges=7,
ndata_schemes={}
edata_schemes={})
>>> bg.batch_size
2
>>> bg.batch_num_nodes()
tensor([4, 3])
>>> bg.batch_num_edges()
tensor([3, 4])
>>> bg.edges()
(tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))
Batch batched graphs
>>> bbg = dgl.batch([bg, bg])
>>> bbg.batch_size
4
>>> bbg.batch_num_nodes()
tensor([4, 3, 4, 3])
>>> bbg.batch_num_edges()
tensor([3, 4, 3, 4])
Batch graphs with feature data
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> bg.ndata['x']
tensor([[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]])
>>> bg.edata['w']
tensor([[1, 1],
[1, 1],
[1, 1],
[0, 0],
[0, 0],
[0, 0],
[0, 0]])
Batch heterographs
>>> hg1 = dgl.heterograph({
... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> bhg
Graph(num_nodes={'user': 3, 'game': 4},
num_edges={('user', 'plays', 'game'): 5},
metagraph=[('user', 'game')])
>>> bhg.batch_size
2
>>> bhg.batch_num_nodes()
{'user' : tensor([2, 1]), 'game' : tensor([1, 3])}
>>> bhg.batch_num_edges()
{('user', 'plays', 'game') : tensor([2, 3])}
See Also
--------
unbatch
"""
if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.')
if node_attrs is not None:
    dgl_warning('Argument node_attrs is deprecated. Please use'
                ' ndata instead.')
    ndata = node_attrs
if edge_attrs is not None:
    dgl_warning('Argument edge_attrs is deprecated. Please use'
                ' edata instead.')
    edata = edge_attrs
if not (is_all(ndata) or isinstance(ndata, list)):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata)))
if not (is_all(edata) or isinstance(edata, list)):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format(
type(edata)))
utils.check_all_same_device(graphs, 'graphs')
utils.check_all_same_idtype(graphs, 'graphs')
relations = graphs[0].canonical_etypes
idtype = graphs[0].idtype
device = graphs[0].device
# Batch graph structure for each relation graph
edge_dict = {}
num_nodes_dict = {}
for rel in relations:
srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0
src, dst = [], []
for g in graphs:
u, v = g.edges(order='eid', etype=rel)
src.append(u + srcnid_off)
dst.append(v + dstnid_off)
srcnid_off += g.number_of_nodes(srctype)
dstnid_off += g.number_of_nodes(dsttype)
edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0))
num_nodes_dict.update({srctype : srcnid_off, dsttype : dstnid_off})
retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device)
# Compute batch num nodes
bnn = {}
for ntype in graphs[0].ntypes:
bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0)
retg.set_batch_num_nodes(bnn)
# Compute batch num edges
bne = {}
for etype in graphs[0].canonical_etypes:
bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0)
retg.set_batch_num_edges(bne)
# Batch node feature
if ndata is not None:
for ntype in graphs[0].ntypes:
feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype))
retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature
if edata is not None:
for etype in graphs[0].canonical_etypes:
feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype))
retg.edges[etype].data.update(ret_feat)
return retg
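The core of the relabeling loop above is plain offset arithmetic: per relation, the endpoints of graph ``i`` are shifted by the accumulated node counts of graphs ``0..i-1``. A standalone sketch with hypothetical tensors:

```python
import torch as th

edges = [(th.tensor([0, 1, 2]), th.tensor([1, 2, 3])),  # graph 0: 4 nodes
         (th.tensor([0, 0]), th.tensor([1, 2]))]        # graph 1: 3 nodes
num_nodes = [4, 3]
src, dst, offset = [], [], 0
for (u, v), n in zip(edges, num_nodes):
    src.append(u + offset)
    dst.append(v + offset)
    offset += n
print(th.cat(src), th.cat(dst))  # tensor([0, 1, 2, 4, 4]) tensor([1, 2, 3, 5, 6])
```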
def _batch_feat_dicts(feat_dicts, keys, feat_dict_name):
"""Internal function to batch feature dictionaries.
Parameters
----------
feat_dicts : list[dict[str, Tensor]]
Feature dictionary list.
keys : list[str]
Feature keys. Can be '__ALL__', meaning batching all features.
feat_dict_name : str
Name of the feature dictionary for reporting errors.
Returns
-------
dict[str, Tensor]
New feature dict.
"""
if len(feat_dicts) == 0:
return {}
# sanity checks
if is_all(keys):
utils.check_all_same_keys(feat_dicts, feat_dict_name)
keys = feat_dicts[0].keys()
else:
utils.check_all_have_keys(feat_dicts, keys, feat_dict_name)
utils.check_all_same_schema(feat_dicts, keys, feat_dict_name)
# concat features
ret_feat = {k : F.cat([fd[k] for fd in feat_dicts], 0) for k in keys}
return ret_feat
def unbatch(g, node_split=None, edge_split=None):
"""Revert the batch operation by split the given graph into a list of small ones.
This is the reverse operation of :func:``dgl.batch``. If the ``node_split``
or the ``edge_split`` is not given, it uses the :func:`DGLGraph.batch_num_nodes`
and :func:`DGLGraph.batch_num_edges` of the input graph.
If the ``node_split`` or the ``edge_split`` arguments are given,
it will partition the graph according to the given segments. One must assure
that the partition is valid -- edges of the i^th graph only connect nodes
belong to the i^th graph. Otherwise, an error will be thrown.
The function supports heterograph input, in which case the two split
section arguments shall be of dictionary type -- similar to the
:func:`DGLGraph.batch_num_nodes`
and :func:`DGLGraph.batch_num_edges` attributes of a heterograph.
Parameters
----------
g : DGLGraph
Input graph to unbatch.
node_split : Tensor, dict[str, Tensor], optional
Number of nodes of each result graph.
edge_split : Tensor, dict[str, Tensor], optional
Number of edges of each result graph.
Returns
-------
list[DGLGraph]
Unbatched list of graphs.
Examples
--------
Unbatch a batched graph
>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> # add features
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> f1, f2 = dgl.unbatch(bg)
>>> f1
Graph(num_nodes=4, num_edges=3,
ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
>>> f2
Graph(num_nodes=3, num_edges=4,
ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
With provided split arguments:
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))
>>> bg = dgl.batch([g1, g2, g3])
>>> bg.batch_num_nodes()
tensor([4, 3, 2])
>>> bg.batch_num_edges()
tensor([3, 4, 1])
>>> # unbatch but merge g2 and g3
>>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))
>>> f1
Graph(num_nodes=4, num_edges=3,
ndata_schemes={}
edata_schemes={})
>>> f2
Graph(num_nodes=5, num_edges=5,
ndata_schemes={}
edata_schemes={})
Heterograph input
>>> hg1 = dgl.heterograph({
... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> f1, f2 = dgl.unbatch(bhg)
>>> f1
Graph(num_nodes={'user': 2, 'game': 1},
num_edges={('user', 'plays', 'game'): 2},
metagraph=[('user', 'game')])
>>> f2
Graph(num_nodes={'user': 1, 'game': 3},
num_edges={('user', 'plays', 'game'): 3},
metagraph=[('user', 'game')])
See Also
--------
batch
"""
num_split = None
# Parse node_split
if node_split is None:
node_split = {ntype : g.batch_num_nodes(ntype) for ntype in g.ntypes}
elif not isinstance(node_split, Mapping):
if len(g.ntypes) != 1:
raise DGLError('Must provide a dictionary for argument node_split when'
' there are multiple node types.')
node_split = {g.ntypes[0] : node_split}
if node_split.keys() != set(g.ntypes):
raise DGLError('Must specify node_split for each node type.')
for split in node_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All node_split and edge_split must specify the same number'
' of split sizes.')
num_split = len(split)
# Parse edge_split
if edge_split is None:
edge_split = {etype : g.batch_num_edges(etype) for etype in g.canonical_etypes}
elif not isinstance(edge_split, Mapping):
if len(g.etypes) != 1:
raise DGLError('Must provide a dictionary for argument edge_split when'
' there are multiple edge types.')
edge_split = {g.canonical_etypes[0] : edge_split}
if edge_split.keys() != set(g.canonical_etypes):
raise DGLError('Must specify edge_split for each canonical edge type.')
for split in edge_split.values():
if num_split is not None and num_split != len(split):
raise DGLError('All node_split and edge_split must specify the same number'
' of split sizes.')
num_split = len(split)
node_split = {k : F.asnumpy(split).tolist() for k, split in node_split.items()}
edge_split = {k : F.asnumpy(split).tolist() for k, split in edge_split.items()}
# Split edges for each relation
edge_dict_per = [{} for i in range(num_split)]
for rel in g.canonical_etypes:
srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0
u, v = g.edges(order='eid', etype=rel)
us = F.split(u, edge_split[rel], 0)
vs = F.split(v, edge_split[rel], 0)
for i, (subu, subv) in enumerate(zip(us, vs)):
edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
srcnid_off += node_split[srctype][i]
dstnid_off += node_split[dsttype][i]
num_nodes_dict_per = [{k : split[i] for k, split in node_split.items()}
for i in range(num_split)]
# Create graphs
gs = [convert.heterograph(edge_dict, num_nodes_dict, validate=True, idtype=g.idtype)
for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)]
# Unbatch node features
for ntype in g.ntypes:
for key, feat in g.nodes[ntype].data.items():
subfeats = F.split(feat, node_split[ntype], 0)
for subg, subf in zip(gs, subfeats):
subg.nodes[ntype].data[key] = subf
# Unbatch edge features
for etype in g.canonical_etypes:
for key, feat in g.edges[etype].data.items():
subfeats = F.split(feat, edge_split[etype], 0)
for subg, subf in zip(gs, subfeats):
subg.edges[etype].data[key] = subf
return gs
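Undoing the batch reduces to splitting by the recorded per-graph counts; the same slicing applies to edge endpoints and to features. A toy sketch of the feature case (PyTorch ``split``, hypothetical sizes):

```python
import torch as th

feat = th.arange(7).unsqueeze(1)       # stacked node feature, 7 = 4 + 3 rows
parts = th.split(feat, [4, 3], dim=0)  # one slice per original graph
assert [p.shape[0] for p in parts] == [4, 3]
```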
#### DEPRECATED APIS ####
def batch_hetero(*args, **kwargs):
"""DEPREACTED: please use dgl.batch """
dgl_warning('From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely'
' replace dgl.batch_hetero with dgl.batch')
return batch(*args, **kwargs)
def unbatch_hetero(*args, **kwargs):
"""DEPREACTED: please use dgl.unbatch """
dgl_warning('From v0.5, DGLHeteroGraph is merged into DGLGraph. You can safely'
' replace dgl.unbatch_hetero with dgl.unbatch')
return unbatch(*args, **kwargs)