Unverified Commit 929d8634 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Sparse-sparse matrix multiplication, addition, and masking (#2753)

* test

* more stuff

* add test

* fixes

* optimize algo

* replace unordered_map with arrays

* lint

* lint x2

* oops

* disable gpu csrmm tests

* remove gpu invocation

* optimize with openmp

* remove python functions

* add back with docstrings

* lint

* lint

* update python interface

* functionize

* functionize

* lint

* lint
parent d04d59ee
...@@ -267,6 +267,7 @@ if(BUILD_CPP_TEST) ...@@ -267,6 +267,7 @@ if(BUILD_CPP_TEST)
include_directories("third_party/dlpack/include") include_directories("third_party/dlpack/include")
include_directories("third_party/xbyak") include_directories("third_party/xbyak")
include_directories("third_party/dmlc-core/include") include_directories("third_party/dmlc-core/include")
include_directories("third_party/phmap")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc) file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES}) add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main) target_link_libraries(runUnitTests gtest gtest_main)
......
...@@ -25,10 +25,10 @@ enum class SparseFormat { ...@@ -25,10 +25,10 @@ enum class SparseFormat {
/*! /*!
* \brief Sparse format codes * \brief Sparse format codes
*/ */
const dgl_format_code_t all_code = 0x7; const dgl_format_code_t ALL_CODE = 0x7;
const dgl_format_code_t coo_code = 0x1; const dgl_format_code_t COO_CODE = 0x1;
const dgl_format_code_t csr_code = 0x2; const dgl_format_code_t CSR_CODE = 0x2;
const dgl_format_code_t csc_code = 0x4; const dgl_format_code_t CSC_CODE = 0x4;
// Parse sparse format from string. // Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) { inline SparseFormat ParseSparseFormat(const std::string& name) {
...@@ -55,11 +55,11 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) { ...@@ -55,11 +55,11 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) { inline std::vector<SparseFormat> CodeToSparseFormats(dgl_format_code_t code) {
std::vector<SparseFormat> ret; std::vector<SparseFormat> ret;
if (code & coo_code) if (code & COO_CODE)
ret.push_back(SparseFormat::kCOO); ret.push_back(SparseFormat::kCOO);
if (code & csr_code) if (code & CSR_CODE)
ret.push_back(SparseFormat::kCSR); ret.push_back(SparseFormat::kCSR);
if (code & csc_code) if (code & CSC_CODE)
ret.push_back(SparseFormat::kCSC); ret.push_back(SparseFormat::kCSC);
return ret; return ret;
} }
...@@ -70,13 +70,13 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) { ...@@ -70,13 +70,13 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
for (auto format : formats) { for (auto format : formats) {
switch (format) { switch (format) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
ret |= coo_code; ret |= COO_CODE;
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
ret |= csr_code; ret |= CSR_CODE;
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
ret |= csc_code; ret |= CSC_CODE;
break; break;
default: default:
LOG(FATAL) << "Only support COO/CSR/CSC formats."; LOG(FATAL) << "Only support COO/CSR/CSC formats.";
...@@ -87,19 +87,19 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) { ...@@ -87,19 +87,19 @@ SparseFormatsToCode(const std::vector<SparseFormat> &formats) {
inline std::string CodeToStr(dgl_format_code_t code) { inline std::string CodeToStr(dgl_format_code_t code) {
std::string ret = ""; std::string ret = "";
if (code & coo_code) if (code & COO_CODE)
ret += "coo "; ret += "coo ";
if (code & csr_code) if (code & CSR_CODE)
ret += "csr "; ret += "csr ";
if (code & csc_code) if (code & CSC_CODE)
ret += "csc "; ret += "csc ";
return ret; return ret;
} }
inline SparseFormat DecodeFormat(dgl_format_code_t code) { inline SparseFormat DecodeFormat(dgl_format_code_t code) {
if (code & coo_code) if (code & COO_CODE)
return SparseFormat::kCOO; return SparseFormat::kCOO;
if (code & csc_code) if (code & CSC_CODE)
return SparseFormat::kCSC; return SparseFormat::kCSC;
return SparseFormat::kCSR; return SparseFormat::kCSR;
} }
......
...@@ -609,7 +609,7 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -609,7 +609,7 @@ HeteroGraphPtr CreateHeteroGraph(
*/ */
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, dgl_format_code_t formats = all_code); IdArray row, IdArray col, dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from COO input. * \brief Create a heterograph from COO input.
...@@ -620,7 +620,7 @@ HeteroGraphPtr CreateFromCOO( ...@@ -620,7 +620,7 @@ HeteroGraphPtr CreateFromCOO(
*/ */
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSR input. * \brief Create a heterograph from CSR input.
...@@ -636,7 +636,7 @@ HeteroGraphPtr CreateFromCOO( ...@@ -636,7 +636,7 @@ HeteroGraphPtr CreateFromCOO(
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSR input. * \brief Create a heterograph from CSR input.
...@@ -647,7 +647,7 @@ HeteroGraphPtr CreateFromCSR( ...@@ -647,7 +647,7 @@ HeteroGraphPtr CreateFromCSR(
*/ */
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSC input. * \brief Create a heterograph from CSC input.
...@@ -663,7 +663,7 @@ HeteroGraphPtr CreateFromCSR( ...@@ -663,7 +663,7 @@ HeteroGraphPtr CreateFromCSR(
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Create a heterograph from CSC input. * \brief Create a heterograph from CSC input.
...@@ -674,7 +674,7 @@ HeteroGraphPtr CreateFromCSC( ...@@ -674,7 +674,7 @@ HeteroGraphPtr CreateFromCSC(
*/ */
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief Extract the subgraph of the in edges of the given nodes. * \brief Extract the subgraph of the in edges of the given nodes.
...@@ -830,13 +830,13 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); ...@@ -830,13 +830,13 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
#define FORMAT_HAS_CSC(format) \ #define FORMAT_HAS_CSC(format) \
((format) & csc_code) ((format) & CSC_CODE)
#define FORMAT_HAS_CSR(format) \ #define FORMAT_HAS_CSR(format) \
((format) & csr_code) ((format) & CSR_CODE)
#define FORMAT_HAS_COO(format) \ #define FORMAT_HAS_COO(format) \
((format) & coo_code) ((format) & COO_CODE)
} // namespace dgl } // namespace dgl
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include "array.h" #include "array.h"
#include "./bcast.h" #include "./bcast.h"
...@@ -51,6 +52,29 @@ void SDDMM(const std::string& op, ...@@ -51,6 +52,29 @@ void SDDMM(const std::string& op,
NDArray efeat, NDArray efeat,
NDArray out); NDArray out);
/*!
* \brief Sparse-sparse matrix multiplication.
*
* \note B is transposed (i.e. in CSC format).
*/
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights);
/*!
* \brief Sparse-sparse matrix summation.
*/
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
/*!
* \brief Return a sparse matrix with the values of A but nonzero entry locations of B.
*/
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
"""Module for sparse matrix operators.""" """Module for sparse matrix operators."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from __future__ import absolute_import from __future__ import absolute_import
import dgl.ndarray as nd from . import ndarray as nd
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from . import backend as F from . import backend as F
...@@ -366,5 +366,124 @@ def _bwd_segment_cmp(feat, arg, m): ...@@ -366,5 +366,124 @@ def _bwd_segment_cmp(feat, arg, m):
to_dgl_nd_for_write(out)) to_dgl_nd_for_write(out))
return out return out
class CSRMatrix(object):
"""Device- and backend-agnostic sparse matrix in CSR format.
Parameters
----------
data : Tensor
The data array.
indices : Tensor
The column indices array.
indptr : Tensor
The row index pointer array.
num_rows : int
The number of rows.
num_cols : int
The number of columns.
"""
def __init__(self, data, indices, indptr, num_rows, num_cols):
self.indptr = indptr
self.indices = indices
self.data = data
self.shape = (num_rows, num_cols)
def csrmm(A, B):
"""Sparse-sparse matrix multiplication.
This is an internal function whose interface is subject to changes.
Parameters
----------
A : dgl.sparse.CSRMatrix
The left operand
B : dgl.sparse.CSRMatrix
The right operand
Returns
-------
dgl.sparse.CSRMatrix
The result
"""
A_indptr = F.zerocopy_from_numpy(A.indptr)
A_indices = F.zerocopy_from_numpy(A.indices)
A_data = F.zerocopy_from_numpy(A.data)
B_indptr = F.zerocopy_from_numpy(B.indptr)
B_indices = F.zerocopy_from_numpy(B.indices)
B_data = F.zerocopy_from_numpy(B.data)
C_indptr, C_indices, C_data = _CAPI_DGLCSRMM(
A.shape[0], A.shape[1], B.shape[1],
F.to_dgl_nd(A_indptr),
F.to_dgl_nd(A_indices),
F.to_dgl_nd(A_data),
F.to_dgl_nd(B_indptr),
F.to_dgl_nd(B_indices),
F.to_dgl_nd(B_data))
return CSRMatrix(
F.from_dgl_nd(C_data),
F.from_dgl_nd(C_indices),
F.from_dgl_nd(C_indptr),
A.shape[0],
B.shape[1])
def csrsum(As):
"""Sparse-sparse matrix summation.
This is an internal function whose interface is subject to changes.
Parameters
----------
As : List[dgl.sparse.CSRMatrix]
List of scipy sparse matrices in CSR format.
Returns
-------
dgl.sparse.CSRMatrix
The result
"""
A_indptr = [F.zerocopy_from_numpy(x.indptr) for x in As]
A_indices = [F.zerocopy_from_numpy(x.indices) for x in As]
A_data = [F.zerocopy_from_numpy(x.data) for x in As]
C_indptr, C_indices, C_data = _CAPI_DGLCSRSum(
As[0].shape[0], As[0].shape[1],
[F.to_dgl_nd(x) for x in A_indptr],
[F.to_dgl_nd(x) for x in A_indices],
[F.to_dgl_nd(x) for x in A_data])
return CSRMatrix(
F.from_dgl_nd(C_data),
F.from_dgl_nd(C_indices),
F.from_dgl_nd(C_indptr),
As[0].shape[0], As[0].shape[1])
def csrmask(A, B):
"""Sparse-sparse matrix masking operation that computes ``A[B != 0]``.
This is an internal function whose interface is subject to changes.
Parameters
----------
A : dgl.sparse.CSRMatrix
The left operand
B : dgl.sparse.CSRMatrix
The right operand
Returns
-------
Tensor
The result
"""
A_indptr = F.zerocopy_from_numpy(A.indptr)
A_indices = F.zerocopy_from_numpy(A.indices)
A_data = F.zerocopy_from_numpy(A.data)
B_indptr = F.zerocopy_from_numpy(B.indptr)
B_indices = F.zerocopy_from_numpy(B.indices)
B_data = _CAPI_DGLCSRMask(
A.shape[0], A.shape[1],
F.to_dgl_nd(A_indptr),
F.to_dgl_nd(A_indices),
F.to_dgl_nd(A_data),
F.to_dgl_nd(B_indptr),
F.to_dgl_nd(B_indices))
return F.from_dgl_nd(B_data)
_init_api("dgl.sparse") _init_api("dgl.sparse")
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/csr_mask.cc
* \brief CSR Masking Operation
*/
#include <dgl/array.h>
#include <parallel_hashmap/phmap.h>
#include <vector>
#include "array_utils.h"
namespace dgl {
using dgl::runtime::NDArray;
namespace aten {
namespace {
// TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType, typename DType>
void ComputeValues(
const IdType* A_indptr,
const IdType* A_indices,
const IdType* A_eids,
const DType* A_data,
const IdType* B_indptr,
const IdType* B_indices,
const IdType* B_eids,
DType* C_data,
int64_t M) {
phmap::flat_hash_map<IdType, DType> map;
#pragma omp parallel for firstprivate(map)
for (IdType i = 0; i < M; ++i) {
map.clear();
for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
IdType kA = A_indices[u];
map[kA] = A_data[A_eids ? A_eids[u] : u];
}
for (IdType v = B_indptr[i]; v < B_indptr[i + 1]; ++v) {
IdType kB = B_indices[v];
auto it = map.find(kB);
C_data[B_eids ? B_eids[v] : v] = (it != map.end()) ? it->second : 0;
}
}
}
}; // namespace
template <int XPU, typename IdType, typename DType>
NDArray CSRMask(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B) {
CHECK_EQ(A.num_rows, B.num_rows) << "Number of rows must match.";
CHECK_EQ(A.num_cols, B.num_cols) << "Number of columns must match.";
const bool A_has_eid = !IsNullArray(A.data);
const bool B_has_eid = !IsNullArray(B.data);
const IdType* A_indptr = A.indptr.Ptr<IdType>();
const IdType* A_indices = A.indices.Ptr<IdType>();
const IdType* A_eids = A_has_eid ? A.data.Ptr<IdType>() : nullptr;
const IdType* B_indptr = B.indptr.Ptr<IdType>();
const IdType* B_indices = B.indices.Ptr<IdType>();
const IdType* B_eids = B_has_eid ? B.data.Ptr<IdType>() : nullptr;
const DType* A_data = A_weights.Ptr<DType>();
const int64_t M = A.num_rows;
const int64_t N = A.num_cols;
NDArray C_weights = NDArray::Empty({B.indices->shape[0]}, A_weights->dtype, A_weights->ctx);
DType* C_data = C_weights.Ptr<DType>();
ComputeValues(A_indptr, A_indices, A_eids, A_data, B_indptr, B_indices, B_eids, C_data, M);
return C_weights;
}
template NDArray CSRMask<kDLCPU, int32_t, float>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int64_t, float>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int32_t, double>(const CSRMatrix&, NDArray, const CSRMatrix&);
template NDArray CSRMask<kDLCPU, int64_t, double>(const CSRMatrix&, NDArray, const CSRMatrix&);
}; // namespace aten
}; // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/csr_mm.cc
* \brief CSR Matrix Multiplication
*/
#include <dgl/array.h>
#include <parallel_hashmap/phmap.h>
#include <vector>
#include "array_utils.h"
namespace dgl {
using dgl::runtime::NDArray;
namespace aten {
namespace {
// TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType>
void CountNNZPerRow(
const IdType* A_indptr,
const IdType* A_indices,
const IdType* B_indptr,
const IdType* B_indices,
IdType* C_indptr_data,
int64_t M) {
phmap::flat_hash_set<IdType> set;
#pragma omp parallel for firstprivate(set)
for (int64_t i = 0; i < M; ++i) {
set.clear();
for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
IdType w = A_indices[u];
for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v)
set.insert(B_indices[v]);
}
C_indptr_data[i] = set.size();
}
}
template <typename IdType>
int64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {
int64_t nnz = 0;
IdType len = 0;
for (IdType i = 0; i < M; ++i) {
len = C_indptr_data[i];
C_indptr_data[i] = nnz;
nnz += len;
}
C_indptr_data[M] = nnz;
return nnz;
}
template <typename IdType, typename DType>
void ComputeIndicesAndData(
const IdType* A_indptr,
const IdType* A_indices,
const IdType* A_eids,
const DType* A_data,
const IdType* B_indptr,
const IdType* B_indices,
const IdType* B_eids,
const DType* B_data,
const IdType* C_indptr_data,
IdType* C_indices_data,
DType* C_weights_data,
int64_t M) {
phmap::flat_hash_map<IdType, DType> map;
#pragma omp parallel for firstprivate(map)
for (int64_t i = 0; i < M; ++i) {
map.clear();
for (IdType u = A_indptr[i]; u < A_indptr[i + 1]; ++u) {
IdType w = A_indices[u];
DType vA = A_data[A_eids ? A_eids[u] : u];
for (IdType v = B_indptr[w]; v < B_indptr[w + 1]; ++v) {
IdType t = B_indices[v];
DType vB = B_data[B_eids ? B_eids[v] : v];
map[t] += vA * vB;
}
}
IdType v = C_indptr_data[i];
for (auto it : map) {
C_indices_data[v] = it.first;
C_weights_data[v] = it.second;
++v;
}
}
}
}; // namespace
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B,
NDArray B_weights) {
CHECK_EQ(A.num_cols, B.num_rows) << "A's number of columns must equal to B's number of rows";
const bool A_has_eid = !IsNullArray(A.data);
const bool B_has_eid = !IsNullArray(B.data);
const IdType* A_indptr = A.indptr.Ptr<IdType>();
const IdType* A_indices = A.indices.Ptr<IdType>();
const IdType* A_eids = A_has_eid ? A.data.Ptr<IdType>() : nullptr;
const IdType* B_indptr = B.indptr.Ptr<IdType>();
const IdType* B_indices = B.indices.Ptr<IdType>();
const IdType* B_eids = B_has_eid ? B.data.Ptr<IdType>() : nullptr;
const DType* A_data = A_weights.Ptr<DType>();
const DType* B_data = B_weights.Ptr<DType>();
const int64_t M = A.num_rows;
const int64_t P = B.num_cols;
IdArray C_indptr = IdArray::Empty({M + 1}, A.indptr->dtype, A.indptr->ctx);
IdType* C_indptr_data = C_indptr.Ptr<IdType>();
CountNNZPerRow<IdType>(A_indptr, A_indices, B_indptr, B_indices, C_indptr_data, M);
int64_t nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);
// Allocate indices and weights array
IdArray C_indices = IdArray::Empty({nnz}, A.indices->dtype, A.indices->ctx);
NDArray C_weights = NDArray::Empty({nnz}, A_weights->dtype, A_weights->ctx);
IdType* C_indices_data = C_indices.Ptr<IdType>();
DType* C_weights_data = C_weights.Ptr<DType>();
ComputeIndicesAndData<IdType, DType>(
A_indptr, A_indices, A_eids, A_data,
B_indptr, B_indices, B_eids, B_data,
C_indptr_data, C_indices_data, C_weights_data, M);
return {CSRMatrix(M, P, C_indptr, C_indices), C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
}; // namespace aten
}; // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/csr_sum.cc
* \brief CSR Summation
*/
#include <dgl/array.h>
#include <parallel_hashmap/phmap.h>
#include <vector>
#include "array_utils.h"
namespace dgl {
using dgl::runtime::NDArray;
namespace aten {
namespace {
// TODO(BarclayII): avoid using map for sorted CSRs
template <typename IdType>
void CountNNZPerRow(
const std::vector<const IdType*>& A_indptr,
const std::vector<const IdType*>& A_indices,
IdType* C_indptr_data,
int64_t M) {
int64_t n = A_indptr.size();
phmap::flat_hash_set<IdType> set;
#pragma omp parallel for firstprivate(set)
for (IdType i = 0; i < M; ++i) {
set.clear();
for (int64_t k = 0; k < n; ++k) {
for (IdType u = A_indptr[k][i]; u < A_indptr[k][i + 1]; ++u)
set.insert(A_indices[k][u]);
}
C_indptr_data[i] = set.size();
}
}
template <typename IdType>
int64_t ComputeIndptrInPlace(IdType* C_indptr_data, int64_t M) {
int64_t nnz = 0;
IdType len = 0;
for (IdType i = 0; i < M; ++i) {
len = C_indptr_data[i];
C_indptr_data[i] = nnz;
nnz += len;
}
C_indptr_data[M] = nnz;
return nnz;
}
template <typename IdType, typename DType>
void ComputeIndicesAndData(
const std::vector<const IdType*>& A_indptr,
const std::vector<const IdType*>& A_indices,
const std::vector<const IdType*>& A_eids,
const std::vector<const DType*>& A_data,
const IdType* C_indptr_data,
IdType* C_indices_data,
DType* C_weights_data,
int64_t M) {
int64_t n = A_indptr.size();
phmap::flat_hash_map<IdType, DType> map;
#pragma omp parallel for firstprivate(map)
for (int64_t i = 0; i < M; ++i) {
map.clear();
for (int64_t k = 0; k < n; ++k) {
for (IdType u = A_indptr[k][i]; u < A_indptr[k][i + 1]; ++u) {
IdType kA = A_indices[k][u];
DType vA = A_data[k][A_eids[k] ? A_eids[k][u] : u];
map[kA] += vA;
}
}
IdType j = C_indptr_data[i];
for (auto it : map) {
C_indices_data[j] = it.first;
C_weights_data[j] = it.second;
++j;
}
}
}
}; // namespace
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights) {
CHECK(A.size() > 0) << "List of matrices can't be empty.";
CHECK_EQ(A.size(), A_weights.size()) << "List of matrices and weights must have same length";
const int64_t M = A[0].num_rows;
const int64_t N = A[0].num_cols;
const int64_t n = A.size();
std::vector<bool> A_has_eid(n);
std::vector<const IdType*> A_indptr(n);
std::vector<const IdType*> A_indices(n);
std::vector<const IdType*> A_eids(n);
std::vector<const DType*> A_data(n);
for (int64_t i = 0; i < n; ++i) {
const CSRMatrix& csr = A[i];
const NDArray& data = A_weights[i];
A_has_eid[i] = !IsNullArray(csr.data);
A_indptr[i] = csr.indptr.Ptr<IdType>();
A_indices[i] = csr.indices.Ptr<IdType>();
A_eids[i] = A_has_eid[i] ? csr.data.Ptr<IdType>() : nullptr;
A_data[i] = data.Ptr<DType>();
}
IdArray C_indptr = IdArray::Empty({M + 1}, A[0].indptr->dtype, A[0].indptr->ctx);
IdType* C_indptr_data = C_indptr.Ptr<IdType>();
CountNNZPerRow<IdType>(A_indptr, A_indices, C_indptr_data, M);
IdType nnz = ComputeIndptrInPlace<IdType>(C_indptr_data, M);
// Allocate indices and weights array
IdArray C_indices = IdArray::Empty({nnz}, A[0].indices->dtype, A[0].indices->ctx);
NDArray C_weights = NDArray::Empty({nnz}, A_weights[0]->dtype, A_weights[0]->ctx);
IdType* C_indices_data = C_indices.Ptr<IdType>();
DType* C_weights_data = C_weights.Ptr<DType>();
ComputeIndicesAndData<IdType, DType>(
A_indptr, A_indices, A_eids, A_data,
C_indptr_data, C_indices_data, C_weights_data, M);
return {CSRMatrix(M, N, C_indptr, C_indices), C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
}; // namespace aten
}; // namespace dgl
...@@ -30,7 +30,7 @@ void SpMM(const std::string& op, const std::string& reduce, ...@@ -30,7 +30,7 @@ void SpMM(const std::string& op, const std::string& reduce,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
// TODO(zihao): format tuning // TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, csc_code); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
const auto& bcast = CalcBcastOff(op, ufeat, efeat); const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
...@@ -61,7 +61,7 @@ void SDDMM(const std::string& op, ...@@ -61,7 +61,7 @@ void SDDMM(const std::string& op,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
// TODO(zihao): format tuning // TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, coo_code); SparseFormat format = graph->SelectFormat(0, COO_CODE);
const auto &bcast = CalcBcastOff(op, lhs, rhs); const auto &bcast = CalcBcastOff(op, lhs, rhs);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
...@@ -84,7 +84,7 @@ void SDDMM(const std::string& op, ...@@ -84,7 +84,7 @@ void SDDMM(const std::string& op,
} }
NDArray GetEdgeMapping(HeteroGraphRef graph) { NDArray GetEdgeMapping(HeteroGraphRef graph) {
SparseFormat format = graph->SelectFormat(0, csc_code); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
return graph.sptr()->GetCSCMatrix(0).data; return graph.sptr()->GetCSCMatrix(0).data;
} else { } else {
...@@ -129,6 +129,86 @@ void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) { ...@@ -129,6 +129,86 @@ void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
}); });
} }
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights) {
CheckCtx(
A.indptr->ctx,
{A_weights, B_weights},
{"A's edge weights", "B's edge weights"});
CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype) << "ID types of two graphs must match.";
CHECK_EQ(A_weights->dtype, B_weights->dtype) << "Data types of two edge weights must match.";
std::pair<CSRMatrix, NDArray> ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented
ATEN_XPU_SWITCH(A.indptr->ctx.device_type, XPU, "CSRMM", {
ATEN_ID_TYPE_SWITCH(A.indptr->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
ret = CSRMM<XPU, IdType, DType>(A, A_weights, B, B_weights);
});
});
});
return ret;
}
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights) {
CHECK(A.size() > 0) << "The list of graphs must not be empty.";
CHECK_EQ(A.size(), A_weights.size()) <<
"The list of edge weights must have the same length as the list of graphs.";
auto ctx = A[0].indptr->ctx;
auto idtype = A[0].indptr->dtype;
auto dtype = A_weights[0]->dtype;
for (size_t i = 0; i < A.size(); ++i) {
CHECK_EQ(A[i].indptr->ctx, ctx) << "The devices of all graphs must be equal.";
CHECK_EQ(A[i].indptr->dtype, idtype) << "The ID types of all graphs must be equal.";
CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0]) <<
"Shape of edge weights does not match the number of edges.";
CHECK_EQ(A_weights[i]->ctx, ctx) <<
"The devices of edge weights must be the same as that of the graphs.";
CHECK_EQ(A_weights[i]->dtype, dtype) <<
"The data types of all edge weights must be equal.";
}
std::pair<CSRMatrix, NDArray> ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented
ATEN_XPU_SWITCH(ctx.device_type, XPU, "CSRSum", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", {
ret = CSRSum<XPU, IdType, DType>(A, A_weights);
});
});
});
return ret;
}
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B) {
CHECK_EQ(A.indptr->ctx, A_weights->ctx) <<
"Device of the graph and the edge weights must match.";
CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype) << "ID types of two graphs must match.";
CHECK_EQ(A_weights->shape[0], A.indices->shape[0]) <<
"Shape of edge weights does not match the number of edges.";
auto ctx = A.indptr->ctx;
auto idtype = A.indptr->dtype;
auto dtype = A_weights->dtype;
NDArray ret;
// TODO(BarclayII): change to ATEN_XPU_SWITCH_CUDA once the GPU kernels are implemented
ATEN_XPU_SWITCH(ctx.device_type, XPU, "CSRMask", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", {
ret = CSRMask<XPU, IdType, DType>(A, A_weights, B);
});
});
});
return ret;
}
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
...@@ -216,6 +296,64 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping") ...@@ -216,6 +296,64 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
*rv = GetEdgeMapping(graph); *rv = GetEdgeMapping(graph);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0];
int N = args[1];
int P = args[2];
NDArray A_indptr = args[3];
NDArray A_indices = args[4];
NDArray A_data = args[5];
NDArray B_indptr = args[6];
NDArray B_indices = args[7];
NDArray B_data = args[8];
auto result = CSRMM(
CSRMatrix(M, N, A_indptr, A_indices),
A_data,
CSRMatrix(N, P, B_indptr, B_indices),
B_data);
List<Value> ret;
ret.push_back(Value(MakeValue(result.first.indptr)));
ret.push_back(Value(MakeValue(result.first.indices)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0];
int N = args[1];
List<Value> A_indptr = args[2];
List<Value> A_indices = args[3];
List<Value> A_data = args[4];
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_data);
std::vector<CSRMatrix> mats(A_indptr.size());
for (int i = 0; i < A_indptr.size(); ++i)
mats[i] = CSRMatrix(M, N, A_indptr[i]->data, A_indices[i]->data);
auto result = CSRSum(mats, weights);
List<Value> ret;
ret.push_back(Value(MakeValue(result.first.indptr)));
ret.push_back(Value(MakeValue(result.first.indices)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int M = args[0];
int N = args[1];
NDArray A_indptr = args[2];
NDArray A_indices = args[3];
NDArray A_data = args[4];
NDArray B_indptr = args[5];
NDArray B_indices = args[6];
auto result = CSRMask(
CSRMatrix(M, N, A_indptr, A_indices),
A_data,
CSRMatrix(M, N, B_indptr, B_indices));
*rv = result;
});
#ifdef USE_TVM #ifdef USE_TVM
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_LoadModule") DGL_REGISTER_GLOBAL("sparse._CAPI_FG_LoadModule")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -92,6 +93,32 @@ void BackwardSegmentCmp(NDArray feat, ...@@ -92,6 +93,32 @@ void BackwardSegmentCmp(NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
/*!
* \brief Sparse-sparse matrix multiplication
*
* \note B is transposed (i.e. in CSC format).
*/
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B,
NDArray B_weights);
/*!
* \brief Sparse-sparse matrix summation.
*/
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
/*!
* \brief Return a sparse matrix with the values of A but nonzero entry locations of B.
*/
template <int XPU, typename IdType, typename DType>
NDArray CSRMask(const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -23,7 +23,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) { ...@@ -23,7 +23,7 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph())); strm->Write(ImmutableGraph::ToImmutable(graph->meta_graph()));
strm->Write(graph->NumVerticesPerType()); strm->Write(graph->NumVerticesPerType());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, all_code); SparseFormat fmt = graph->SelectFormat(etype, ALL_CODE);
switch (fmt) { switch (fmt) {
case SparseFormat::kCOO: { case SparseFormat::kCOO: {
strm->Write(SparseFormat::kCOO); strm->Write(SparseFormat::kCOO);
...@@ -84,7 +84,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -84,7 +84,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'"; CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted); auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
// TODO(zihao) fix // TODO(zihao) fix
relgraph = CreateFromCOO(num_vtypes, coo, all_code); relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
break; break;
} }
case SparseFormat::kCSR: { case SparseFormat::kCSR: {
...@@ -96,7 +96,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -96,7 +96,7 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted); auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
// TODO(zihao) fix // TODO(zihao) fix
relgraph = CreateFromCSR(num_vtypes, csr, all_code); relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
break; break;
} }
case SparseFormat::kCSC: case SparseFormat::kCSC:
......
...@@ -62,7 +62,7 @@ HeteroSubgraph SampleNeighbors( ...@@ -62,7 +62,7 @@ HeteroSubgraph SampleNeighbors(
induced_edges[etype] = earr.id; induced_edges[etype] = earr.id;
} else { } else {
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? csr_code : csc_code; auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo; COOMatrix sampled_coo;
switch (avail_fmt) { switch (avail_fmt) {
...@@ -148,7 +148,7 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -148,7 +148,7 @@ HeteroSubgraph SampleNeighborsTopk(
induced_edges[etype] = earr.id; induced_edges[etype] = earr.id;
} else { } else {
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? csr_code : csc_code; auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo; COOMatrix sampled_coo;
switch (avail_fmt) { switch (avail_fmt) {
......
...@@ -33,11 +33,11 @@ class HaloHeteroSubgraph : public HeteroSubgraph { ...@@ -33,11 +33,11 @@ class HaloHeteroSubgraph : public HeteroSubgraph {
HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) { HeteroGraphPtr ReorderUnitGraph(UnitGraphPtr ug, IdArray new_order) {
auto format = ug->GetCreatedFormats(); auto format = ug->GetCreatedFormats();
// We only need to reorder one of the graph structure. // We only need to reorder one of the graph structure.
if (format & csc_code) { if (format & CSC_CODE) {
auto cscmat = ug->GetCSCMatrix(0); auto cscmat = ug->GetCSCMatrix(0);
auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order); auto new_cscmat = aten::CSRReorder(cscmat, new_order, new_order);
return UnitGraph::CreateFromCSC(ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSC(ug->NumVertexTypes(), new_cscmat, ug->GetAllowedFormats());
} else if (format & csr_code) { } else if (format & CSR_CODE) {
auto csrmat = ug->GetCSRMatrix(0); auto csrmat = ug->GetCSRMatrix(0);
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order); auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats()); return UnitGraph::CreateFromCSR(ug->NumVertexTypes(), new_csrmat, ug->GetAllowedFormats());
......
...@@ -28,7 +28,7 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) { ...@@ -28,7 +28,7 @@ RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
const int64_t num_etypes = graph->NumEdgeTypes(); const int64_t num_etypes = graph->NumEdgeTypes();
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
const SparseFormat fmt = graph->SelectFormat(etype, coo_code); const SparseFormat fmt = graph->SelectFormat(etype, COO_CODE);
const auto src_dst_types = graph->GetEndpointTypes(etype); const auto src_dst_types = graph->GetEndpointTypes(etype);
const dgl_type_t srctype = src_dst_types.first; const dgl_type_t srctype = src_dst_types.first;
const dgl_type_t dsttype = src_dst_types.second; const dgl_type_t dsttype = src_dst_types.second;
......
...@@ -827,13 +827,13 @@ uint8_t UnitGraph::NumBits() const { ...@@ -827,13 +827,13 @@ uint8_t UnitGraph::NumBits() const {
} }
bool UnitGraph::IsMultigraph() const { bool UnitGraph::IsMultigraph() const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->IsMultigraph(); return ptr->IsMultigraph();
} }
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const { uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
const SparseFormat fmt = SelectFormat(all_code); const SparseFormat fmt = SelectFormat(ALL_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
// TODO(BarclayII): we have a lot of special handling for CSC. // TODO(BarclayII): we have a lot of special handling for CSC.
// Need to have a UnitGraph::CSC backend instead. // Need to have a UnitGraph::CSC backend instead.
...@@ -847,7 +847,7 @@ uint64_t UnitGraph::NumEdges(dgl_type_t etype) const { ...@@ -847,7 +847,7 @@ uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
} }
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const { bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(all_code); const SparseFormat fmt = SelectFormat(ALL_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
vtype = (vtype == SrcType()) ? DstType() : SrcType(); vtype = (vtype == SrcType()) ? DstType() : SrcType();
...@@ -860,7 +860,7 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { ...@@ -860,7 +860,7 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
} }
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->HasEdgeBetween(etype, dst, src); return ptr->HasEdgeBetween(etype, dst, src);
...@@ -870,7 +870,7 @@ bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) con ...@@ -870,7 +870,7 @@ bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) con
BoolArray UnitGraph::HasEdgesBetween( BoolArray UnitGraph::HasEdgesBetween(
dgl_type_t etype, IdArray src, IdArray dst) const { dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->HasEdgesBetween(etype, dst, src); return ptr->HasEdgesBetween(etype, dst, src);
...@@ -879,7 +879,7 @@ BoolArray UnitGraph::HasEdgesBetween( ...@@ -879,7 +879,7 @@ BoolArray UnitGraph::HasEdgesBetween(
} }
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const { IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->Successors(etype, dst); return ptr->Successors(etype, dst);
...@@ -888,13 +888,13 @@ IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const { ...@@ -888,13 +888,13 @@ IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
} }
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const { IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->Successors(etype, src); return ptr->Successors(etype, src);
} }
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->EdgeId(etype, dst, src); return ptr->EdgeId(etype, dst, src);
...@@ -903,7 +903,7 @@ IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { ...@@ -903,7 +903,7 @@ IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
} }
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const { EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
EdgeArray edges = ptr->EdgeIdsAll(etype, dst, src); EdgeArray edges = ptr->EdgeIdsAll(etype, dst, src);
...@@ -914,7 +914,7 @@ EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) cons ...@@ -914,7 +914,7 @@ EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) cons
} }
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const { IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
return ptr->EdgeIdsOne(etype, dst, src); return ptr->EdgeIdsOne(etype, dst, src);
...@@ -924,19 +924,19 @@ IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const ...@@ -924,19 +924,19 @@ IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const
} }
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
const SparseFormat fmt = SelectFormat(coo_code); const SparseFormat fmt = SelectFormat(COO_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->FindEdge(etype, eid); return ptr->FindEdge(etype, eid);
} }
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const { EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
const SparseFormat fmt = SelectFormat(coo_code); const SparseFormat fmt = SelectFormat(COO_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->FindEdges(etype, eids); return ptr->FindEdges(etype, eids);
} }
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const { EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vid); const EdgeArray& ret = ptr->OutEdges(etype, vid);
...@@ -947,7 +947,7 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const { ...@@ -947,7 +947,7 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
} }
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const { EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(csc_code); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vids); const EdgeArray& ret = ptr->OutEdges(etype, vids);
...@@ -958,13 +958,13 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const { ...@@ -958,13 +958,13 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
} }
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const { EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vid); return ptr->OutEdges(etype, vid);
} }
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(csr_code); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vids); return ptr->OutEdges(etype, vids);
} }
...@@ -972,12 +972,12 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { ...@@ -972,12 +972,12 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const { EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
SparseFormat fmt; SparseFormat fmt;
if (order == std::string("eid")) { if (order == std::string("eid")) {
fmt = SelectFormat(coo_code); fmt = SelectFormat(COO_CODE);
} else if (order.empty()) { } else if (order.empty()) {
// arbitrary order // arbitrary order
fmt = SelectFormat(all_code); fmt = SelectFormat(ALL_CODE);
} else if (order == std::string("srcdst")) { } else if (order == std::string("srcdst")) {
fmt = SelectFormat(csr_code); fmt = SelectFormat(CSR_CODE);
} else { } else {
LOG(FATAL) << "Unsupported order request: " << order; LOG(FATAL) << "Unsupported order request: " << order;
return {}; return {};
...@@ -991,7 +991,7 @@ EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const { ...@@ -991,7 +991,7 @@ EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
} }
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csc_code); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->OutDegree(etype, vid); return ptr->OutDegree(etype, vid);
...@@ -1000,7 +1000,7 @@ uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const { ...@@ -1000,7 +1000,7 @@ uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
} }
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(csc_code); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->OutDegrees(etype, vids); return ptr->OutDegrees(etype, vids);
...@@ -1009,38 +1009,38 @@ DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const { ...@@ -1009,38 +1009,38 @@ DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
} }
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutDegree(etype, vid); return ptr->OutDegree(etype, vid);
} }
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutDegrees(etype, vids); return ptr->OutDegrees(etype, vids);
} }
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->SuccVec(etype, vid); return ptr->SuccVec(etype, vid);
} }
DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt)); const auto ptr = std::dynamic_pointer_cast<CSR>(GetFormat(fmt));
CHECK_NOTNULL(ptr); CHECK_NOTNULL(ptr);
return ptr->SuccVec32(etype, vid); return ptr->SuccVec32(etype, vid);
} }
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdgeVec(etype, vid); return ptr->OutEdgeVec(etype, vid);
} }
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csc_code); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->SuccVec(etype, vid); return ptr->SuccVec(etype, vid);
...@@ -1049,7 +1049,7 @@ DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const { ...@@ -1049,7 +1049,7 @@ DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
} }
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(csc_code); SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
return ptr->OutEdgeVec(etype, vid); return ptr->OutEdgeVec(etype, vid);
...@@ -1079,7 +1079,7 @@ std::vector<IdArray> UnitGraph::GetAdj( ...@@ -1079,7 +1079,7 @@ std::vector<IdArray> UnitGraph::GetAdj(
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const { HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(csr_code); SparseFormat fmt = SelectFormat(CSR_CODE);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids); HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
HeteroSubgraph ret; HeteroSubgraph ret;
...@@ -1109,7 +1109,7 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const ...@@ -1109,7 +1109,7 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
HeteroSubgraph UnitGraph::EdgeSubgraph( HeteroSubgraph UnitGraph::EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes) const { const std::vector<IdArray>& eids, bool preserve_nodes) const {
SparseFormat fmt = SelectFormat(coo_code); SparseFormat fmt = SelectFormat(COO_CODE);
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes); auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
HeteroSubgraph ret; HeteroSubgraph ret;
...@@ -1306,7 +1306,7 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom( ...@@ -1306,7 +1306,7 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & csc_code)) if (!(formats_ & CSC_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSC matrix."; CodeToStr(formats_) << ", cannot create CSC matrix.";
CSRPtr ret = in_csr_; CSRPtr ret = in_csr_;
...@@ -1335,7 +1335,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1335,7 +1335,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
/* !\brief Return out csr. If not exist, transpose the other one.*/ /* !\brief Return out csr. If not exist, transpose the other one.*/
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & csr_code)) if (!(formats_ & CSR_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSR matrix."; CodeToStr(formats_) << ", cannot create CSR matrix.";
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
...@@ -1363,7 +1363,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1363,7 +1363,7 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
/* !\brief Return coo. If not exist, create from csr.*/ /* !\brief Return coo. If not exist, create from csr.*/
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & coo_code)) if (!(formats_ & COO_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create COO matrix."; CodeToStr(formats_) << ", cannot create COO matrix.";
COOPtr ret = coo_; COOPtr ret = coo_;
...@@ -1413,11 +1413,11 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1413,11 +1413,11 @@ HeteroGraphPtr UnitGraph::GetAny() const {
dgl_format_code_t UnitGraph::GetCreatedFormats() const { dgl_format_code_t UnitGraph::GetCreatedFormats() const {
dgl_format_code_t ret = 0; dgl_format_code_t ret = 0;
if (in_csr_->defined()) if (in_csr_->defined())
ret |= csc_code; ret |= CSC_CODE;
if (out_csr_->defined()) if (out_csr_->defined())
ret |= csr_code; ret |= CSR_CODE;
if (coo_->defined()) if (coo_->defined())
ret |= coo_code; ret |= COO_CODE;
return ret; return ret;
} }
...@@ -1437,7 +1437,7 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { ...@@ -1437,7 +1437,7 @@ HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
} }
HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const { HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
if (formats == all_code) if (formats == ALL_CODE)
return HeteroGraphPtr( return HeteroGraphPtr(
// TODO(xiangsx) Make it as graph storage.Clone() // TODO(xiangsx) Make it as graph storage.Clone()
new UnitGraph(meta_graph_, new UnitGraph(meta_graph_,
...@@ -1452,9 +1452,9 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const { ...@@ -1452,9 +1452,9 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
: nullptr, : nullptr,
formats)); formats));
int64_t num_vtypes = NumVertexTypes(); int64_t num_vtypes = NumVertexTypes();
if (formats & coo_code) if (formats & COO_CODE)
return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats); return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
if (formats & csr_code) if (formats & CSR_CODE)
return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats); return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats);
return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats); return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
} }
...@@ -1501,7 +1501,7 @@ GraphPtr UnitGraph::AsImmutableGraph() const { ...@@ -1501,7 +1501,7 @@ GraphPtr UnitGraph::AsImmutableGraph() const {
HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const { HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
// TODO(xiangsx) currently we only support homogeneous graph // TODO(xiangsx) currently we only support homogeneous graph
auto fmt = SelectFormat(all_code); auto fmt = SelectFormat(ALL_CODE);
switch (fmt) { switch (fmt) {
case SparseFormat::kCOO: { case SparseFormat::kCOO: {
return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking)); return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking));
...@@ -1541,16 +1541,16 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1541,16 +1541,16 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
// NOTE(zihao): to be compatible with old formats. // NOTE(zihao): to be compatible with old formats.
switch (formats_code & 0xffffffff) { switch (formats_code & 0xffffffff) {
case 0: case 0:
formats_ = all_code; formats_ = ALL_CODE;
break; break;
case 1: case 1:
formats_ = coo_code; formats_ = COO_CODE;
break; break;
case 2: case 2:
formats_ = csr_code; formats_ = CSR_CODE;
break; break;
case 3: case 3:
formats_ = csc_code; formats_ = CSC_CODE;
break; break;
default: default:
LOG(FATAL) << "Load graph failed, formats code " << formats_code << LOG(FATAL) << "Load graph failed, formats code " << formats_code <<
...@@ -1593,7 +1593,7 @@ void UnitGraph::Save(dmlc::Stream* fs) const { ...@@ -1593,7 +1593,7 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix // sparse matrix
auto avail_fmt = SelectFormat(all_code); auto avail_fmt = SelectFormat(ALL_CODE);
fs->Write(static_cast<int64_t>(avail_fmt)); fs->Write(static_cast<int64_t>(avail_fmt));
fs->Write(static_cast<int64_t>(formats_ | 0x100000000)); fs->Write(static_cast<int64_t>(formats_ | 0x100000000));
switch (avail_fmt) { switch (avail_fmt) {
...@@ -1629,7 +1629,7 @@ UnitGraph::ToSimple() const { ...@@ -1629,7 +1629,7 @@ UnitGraph::ToSimple() const {
IdArray count; IdArray count;
IdArray edge_map; IdArray edge_map;
auto avail_fmt = SelectFormat(all_code); auto avail_fmt = SelectFormat(ALL_CODE);
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: { case SparseFormat::kCOO: {
auto ret = aten::COOToSimple(GetCOO()->adj()); auto ret = aten::COOToSimple(GetCOO()->adj());
......
...@@ -174,31 +174,31 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -174,31 +174,31 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph from COO arrays */ /*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, dgl_format_code_t formats = all_code); IdArray row, IdArray col, dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! \brief Create a graph from (out) CSR arrays */ /*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! \brief Create a graph from (in) CSC arrays */ /*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
...@@ -298,7 +298,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -298,7 +298,7 @@ class UnitGraph : public BaseHeteroGraph {
* \param coo coo * \param coo coo
*/ */
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! /*!
* \brief constructor * \brief constructor
...@@ -317,7 +317,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -317,7 +317,7 @@ class UnitGraph : public BaseHeteroGraph {
bool has_in_csr, bool has_in_csr,
bool has_out_csr, bool has_out_csr,
bool has_coo, bool has_coo,
dgl_format_code_t formats = all_code); dgl_format_code_t formats = ALL_CODE);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
......
#include <gtest/gtest.h>
#include <dgl/array.h>
#include <dgl/kernel.h>
#include "../../src/array/cpu/array_utils.h" // PairHash
#include "./common.h"
using namespace dgl;
using namespace dgl::runtime;
namespace {
// Unit tests:
// CSRMM(A, B) == A_mm_B
// CSRSum({A, C}) == A_plus_C
// CSRMask(A, C) = A_mask_C
template <typename IdType, typename DType>
std::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> COOToMap(
aten::COOMatrix coo, NDArray weights) {
std::unordered_map<std::pair<IdType, IdType>, DType, aten::PairHash> map;
for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
IdType irow = aten::IndexSelect<IdType>(coo.row, i);
IdType icol = aten::IndexSelect<IdType>(coo.col, i);
IdType ieid = aten::COOHasData(coo) ? aten::IndexSelect<IdType>(coo.data, i) : i;
DType idata = aten::IndexSelect<DType>(weights, ieid);
map.insert({{irow, icol}, idata});
}
return map;
}
template <typename IdType, typename DType>
bool CSRIsClose(
aten::CSRMatrix A,
aten::CSRMatrix B,
NDArray A_weights,
NDArray B_weights,
DType rtol,
DType atol) {
auto Amap = COOToMap<IdType, DType>(CSRToCOO(A, false), A_weights);
auto Bmap = COOToMap<IdType, DType>(CSRToCOO(B, false), B_weights);
if (Amap.size() != Bmap.size())
return false;
for (auto itA : Amap) {
auto itB = Bmap.find(itA.first);
if (itB == Bmap.end())
return false;
if (fabs(itA.second - itB->second) >= rtol * fabs(itA.second) + atol)
return false;
}
return true;
}
template <typename IdType, typename DType>
std::pair<aten::CSRMatrix, NDArray> CSR_A(DLContext ctx = CTX) {
// matrix([[0. , 0. , 1. , 0.7, 0. ],
// [0. , 0. , 0.5, 0.+, 0. ],
// [0.4, 0.7, 0. , 0.2, 0. ],
// [0. , 0. , 0. , 0. , 0.2]])
// (0.+ indicates that the entry exists but the value is 0.)
auto csr = aten::CSRMatrix(
4, 5,
NDArray::FromVector(std::vector<IdType>({0, 2, 4, 7, 8}), ctx),
NDArray::FromVector(std::vector<IdType>({2, 3, 2, 3, 0, 1, 3, 4}), ctx));
auto weights = NDArray::FromVector(
std::vector<DType>({1.0, 0.7, 0.5, 0.0, 0.4, 0.7, 0.2, 0.2}), ctx);
return {csr, weights};
}
template <typename IdType, typename DType>
std::pair<aten::CSRMatrix, NDArray> CSR_B(DLContext ctx = CTX) {
// matrix([[0. , 0.9, 0. , 0.6, 0. , 0.3],
// [0. , 0. , 0. , 0. , 0. , 0.4],
// [0.+, 0. , 0. , 0. , 0. , 0.9],
// [0.8, 0.2, 0.3, 0.2, 0. , 0. ],
// [0.2, 0.4, 0. , 0. , 0. , 0. ]])
// (0.+ indicates that the entry exists but the value is 0.)
auto csr = aten::CSRMatrix(
5, 6,
NDArray::FromVector(std::vector<IdType>({0, 3, 4, 6, 10, 12}), ctx),
NDArray::FromVector(std::vector<IdType>({1, 3, 5, 5, 0, 5, 0, 1, 2, 3, 0, 1}), ctx));
auto weights = NDArray::FromVector(
std::vector<DType>({0.9, 0.6, 0.3, 0.4, 0.0, 0.9, 0.8, 0.2, 0.3, 0.2, 0.2, 0.4}), ctx);
return {csr, weights};
}
template <typename IdType, typename DType>
std::pair<aten::CSRMatrix, NDArray> CSR_C(DLContext ctx = CTX) {
// matrix([[0. , 0. , 0. , 0.2, 0. ],
// [0. , 0. , 0. , 0.5, 0.4],
// [0. , 0.2, 0. , 0.9, 0.2],
// [0. , 1. , 0. , 0.7, 0. ]])
auto csr = aten::CSRMatrix(
4, 5,
NDArray::FromVector(std::vector<IdType>({0, 1, 3, 6, 8}), ctx),
NDArray::FromVector(std::vector<IdType>({3, 3, 4, 1, 3, 4, 1, 3}), ctx));
auto weights = NDArray::FromVector(
std::vector<DType>({0.2, 0.5, 0.4, 0.2, 0.9, 0.2, 1. , 0.7}), ctx);
return {csr, weights};
}
template <typename IdType, typename DType>
std::pair<aten::CSRMatrix, NDArray> CSR_A_mm_B(DLContext ctx = CTX) {
// matrix([[0.56, 0.14, 0.21, 0.14, 0. , 0.9 ],
// [0.+ , 0.+ , 0.+ , 0.+ , 0. , 0.45],
// [0.16, 0.4 , 0.06, 0.28, 0. , 0.4 ],
// [0.04, 0.08, 0. , 0. , 0. , 0. ]])
// (0.+ indicates that the entry exists but the value is 0.)
auto csr = aten::CSRMatrix(
4, 6,
NDArray::FromVector(std::vector<IdType>({0, 5, 10, 15, 17}), ctx),
NDArray::FromVector(std::vector<IdType>(
{0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1, 2, 3, 5, 0, 1}), ctx));
auto weights = NDArray::FromVector(
std::vector<DType>({
0.56, 0.14, 0.21, 0.14, 0.9 , 0. , 0. , 0. , 0. , 0.45, 0.16, 0.4 , 0.06, 0.28, 0.4 ,
0.04, 0.08}), ctx);
return {csr, weights};
}
template <typename IdType, typename DType>
std::pair<aten::CSRMatrix, NDArray> CSR_A_plus_C(DLContext ctx = CTX) {
auto csr = aten::CSRMatrix(
4, 5,
NDArray::FromVector(std::vector<IdType>({0, 2, 5, 9, 12}), ctx),
NDArray::FromVector(std::vector<IdType>({2, 3, 2, 3, 4, 0, 1, 3, 4, 1, 3, 4}), ctx));
auto weights = NDArray::FromVector(
std::vector<DType>({1. , 0.9, 0.5, 0.5, 0.4, 0.4, 0.9, 1.1, 0.2, 1. , 0.7, 0.2}), ctx);
return {csr, weights};
}
template <typename DType>
NDArray CSR_A_mask_C(DLContext ctx = CTX) {
return NDArray::FromVector(std::vector<DType>({0.7, 0.0, 0.0, 0.7, 0.2, 0.0, 0.0, 0.0}), ctx);
}
template <typename IdType, typename DType>
void _TestCsrmm(DLContext ctx = CTX) {
auto A = CSR_A<IdType, DType>(ctx);
auto B = CSR_B<IdType, DType>(ctx);
auto A_mm_B = aten::CSRMM(A.first, A.second, B.first, B.second);
auto A_mm_B2 = CSR_A_mm_B<IdType, DType>(ctx);
bool result = CSRIsClose<IdType, DType>(A_mm_B.first, A_mm_B2.first, A_mm_B.second, A_mm_B2.second, 1e-4, 1e-4);
ASSERT_TRUE(result);
}
template <typename IdType, typename DType>
void _TestCsrsum(DLContext ctx = CTX) {
auto A = CSR_A<IdType, DType>(ctx);
auto C = CSR_C<IdType, DType>(ctx);
auto A_plus_C = aten::CSRSum({A.first, C.first}, {A.second, C.second});
auto A_plus_C2 = CSR_A_plus_C<IdType, DType>(ctx);
bool result = CSRIsClose<IdType, DType>(
A_plus_C.first, A_plus_C2.first, A_plus_C.second, A_plus_C2.second, 1e-4, 1e-4);
ASSERT_TRUE(result);
}
template <typename IdType, typename DType>
void _TestCsrmask(DLContext ctx = CTX) {
auto A = CSR_A<IdType, DType>(ctx);
auto C = CSR_C<IdType, DType>(ctx);
auto A_mask_C = aten::CSRMask(A.first, A.second, C.first);
auto A_mask_C2 = CSR_A_mask_C<DType>(ctx);
ASSERT_TRUE(ArrayEQ<DType>(A_mask_C, A_mask_C2));
}
TEST(CsrmmTest, TestCsrmm) {
_TestCsrmm<int32_t, float>(CPU);
_TestCsrmm<int32_t, double>(CPU);
_TestCsrmm<int64_t, float>(CPU);
_TestCsrmm<int64_t, double>(CPU);
#ifdef DGL_USE_CUDA
_TestCsrmm<int32_t, float>(GPU);
_TestCsrmm<int32_t, double>(GPU);
_TestCsrmm<int64_t, float>(GPU);
_TestCsrmm<int64_t, double>(GPU);
#endif
}
TEST(CsrmmTest, TestCsrsum) {
_TestCsrsum<int32_t, float>(CPU);
_TestCsrsum<int32_t, double>(CPU);
_TestCsrsum<int64_t, float>(CPU);
_TestCsrsum<int64_t, double>(CPU);
#ifdef DGL_USE_CUDA
_TestCsrsum<int32_t, float>(GPU);
_TestCsrsum<int32_t, double>(GPU);
_TestCsrsum<int64_t, float>(GPU);
_TestCsrsum<int64_t, double>(GPU);
#endif
}
TEST(CsrmmTest, TestCsrmask) {
_TestCsrmask<int32_t, float>(CPU);
_TestCsrmask<int32_t, double>(CPU);
_TestCsrmask<int64_t, float>(CPU);
_TestCsrmask<int64_t, double>(CPU);
#ifdef DGL_USE_CUDA
_TestCsrmask<int32_t, float>(GPU);
_TestCsrmask<int32_t, double>(GPU);
_TestCsrmask<int64_t, float>(GPU);
_TestCsrmask<int64_t, double>(GPU);
#endif
}
}; // namespace
...@@ -19,7 +19,7 @@ TEST(Serialize, UnitGraph_COO) { ...@@ -19,7 +19,7 @@ TEST(Serialize, UnitGraph_COO) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>( auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, coo_code)); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, COO_CODE));
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
...@@ -43,7 +43,7 @@ TEST(Serialize, UnitGraph_CSR) { ...@@ -43,7 +43,7 @@ TEST(Serialize, UnitGraph_CSR) {
auto coo_g = std::dynamic_pointer_cast<UnitGraph>( auto coo_g = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst)); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst));
auto csr_g = auto csr_g =
std::dynamic_pointer_cast<UnitGraph>(coo_g->GetGraphInFormat(csr_code)); std::dynamic_pointer_cast<UnitGraph>(coo_g->GetGraphInFormat(CSR_CODE));
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
......
...@@ -82,19 +82,19 @@ void _TestUnitGraph(DLContext ctx) { ...@@ -82,19 +82,19 @@ void _TestUnitGraph(DLContext ctx) {
auto src = aten::VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = aten::VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = aten::VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = aten::VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, coo_code); auto mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, COO_CODE);
ASSERT_EQ(mg->GetCreatedFormats(), 1); ASSERT_EQ(mg->GetCreatedFormats(), 1);
auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, coo_code); auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, COO_CODE);
auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph()); auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr); ASSERT_TRUE(img != nullptr);
mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, csr_code | coo_code); mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSR_CODE | COO_CODE);
ASSERT_EQ(mg->GetCreatedFormats(), 1); ASSERT_EQ(mg->GetCreatedFormats(), 1);
hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, csr_code | coo_code); hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSR_CODE | COO_CODE);
img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph()); img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr); ASSERT_TRUE(img != nullptr);
mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, csc_code | coo_code); mg = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, CSC_CODE | COO_CODE);
ASSERT_EQ(mg->GetCreatedFormats(), 1); ASSERT_EQ(mg->GetCreatedFormats(), 1);
hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, csc_code | coo_code); hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, CSC_CODE | COO_CODE);
img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph()); img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr); ASSERT_TRUE(img != nullptr);
...@@ -121,7 +121,7 @@ void _TestUnitGraph_GetInCSR(DLContext ctx) { ...@@ -121,7 +121,7 @@ void _TestUnitGraph_GetInCSR(DLContext ctx) {
// test out csr // test out csr
g = CreateFromCSR(2, csr); g = CreateFromCSR(2, csr);
auto g_ptr = g->GetGraphInFormat(csc_code); auto g_ptr = g->GetGraphInFormat(CSC_CODE);
in_csr_matrix = g_ptr->GetCSCMatrix(0); in_csr_matrix = g_ptr->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows); ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols); ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
...@@ -133,7 +133,7 @@ void _TestUnitGraph_GetInCSR(DLContext ctx) { ...@@ -133,7 +133,7 @@ void _TestUnitGraph_GetInCSR(DLContext ctx) {
// test out coo // test out coo
g = CreateFromCOO(2, coo); g = CreateFromCOO(2, coo);
g_ptr = g->GetGraphInFormat(csc_code); g_ptr = g->GetGraphInFormat(CSC_CODE);
in_csr_matrix = g_ptr->GetCSCMatrix(0); in_csr_matrix = g_ptr->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows); ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols); ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
...@@ -151,7 +151,7 @@ void _TestUnitGraph_GetOutCSR(DLContext ctx) { ...@@ -151,7 +151,7 @@ void _TestUnitGraph_GetOutCSR(DLContext ctx) {
const aten::COOMatrix &coo = COO1<IdType>(ctx); const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto g = CreateFromCSC(2, csr); auto g = CreateFromCSC(2, csr);
auto g_ptr = g->GetGraphInFormat(csr_code); auto g_ptr = g->GetGraphInFormat(CSR_CODE);
auto out_csr_matrix = g_ptr->GetCSRMatrix(0); auto out_csr_matrix = g_ptr->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows); ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols); ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
...@@ -170,7 +170,7 @@ void _TestUnitGraph_GetOutCSR(DLContext ctx) { ...@@ -170,7 +170,7 @@ void _TestUnitGraph_GetOutCSR(DLContext ctx) {
// test out coo // test out coo
g = CreateFromCOO(2, coo); g = CreateFromCOO(2, coo);
g_ptr = g->GetGraphInFormat(csr_code); g_ptr = g->GetGraphInFormat(CSR_CODE);
out_csr_matrix = g_ptr->GetCSRMatrix(0); out_csr_matrix = g_ptr->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows); ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols); ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
...@@ -188,7 +188,7 @@ void _TestUnitGraph_GetCOO(DLContext ctx) { ...@@ -188,7 +188,7 @@ void _TestUnitGraph_GetCOO(DLContext ctx) {
const aten::COOMatrix &coo = COO1<IdType>(ctx); const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto g = CreateFromCSC(2, csr); auto g = CreateFromCSC(2, csr);
auto g_ptr = g->GetGraphInFormat(coo_code); auto g_ptr = g->GetGraphInFormat(COO_CODE);
auto out_coo_matrix = g_ptr->GetCOOMatrix(0); auto out_coo_matrix = g_ptr->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows); ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols); ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
...@@ -200,7 +200,7 @@ void _TestUnitGraph_GetCOO(DLContext ctx) { ...@@ -200,7 +200,7 @@ void _TestUnitGraph_GetCOO(DLContext ctx) {
// test out csr // test out csr
g = CreateFromCSR(2, csr); g = CreateFromCSR(2, csr);
g_ptr = g->GetGraphInFormat(coo_code); g_ptr = g->GetGraphInFormat(COO_CODE);
out_coo_matrix = g_ptr->GetCOOMatrix(0); out_coo_matrix = g_ptr->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows); ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols); ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment