Unverified commit f25b1a06, authored by Zihao Ye, committed by GitHub

[Feature] Autograd of gspmm and gsddmm on PyTorch/MXNet/Tensorflow (#1680)

* init

* reverse (by minjie)

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* gpu

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* imidiot

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* remove redundency

* upd

* upd

* upd

* cache

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* trigger

* upd

* fix

* upd

* unused import

* upd

* upd
parent c13903bf
@@ -24,11 +24,11 @@ namespace aten {
   } else if ((op) == "div") {                                \
     typedef cuda::binary::Div<DType> Op;                     \
     { __VA_ARGS__ }                                          \
-  } else if ((op) == "copy_u") {                             \
-    typedef cuda::binary::CopyU<DType> Op;                   \
+  } else if ((op) == "copy_lhs") {                           \
+    typedef cuda::binary::CopyLhs<DType> Op;                 \
     { __VA_ARGS__ }                                          \
-  } else if ((op) == "copy_e") {                             \
-    typedef cuda::binary::CopyE<DType> Op;                   \
+  } else if ((op) == "copy_rhs") {                           \
+    typedef cuda::binary::CopyRhs<DType> Op;                 \
     { __VA_ARGS__ }                                          \
   } else if ((op) == "dot") {                                \
     typedef cuda::binary::Dot<DType> Op;                     \
@@ -38,6 +38,37 @@ namespace aten {
   }                                                          \
 } while (0)
 
+#define SWITCH_RHS(rhs_target, RhsTarget, ...)               \
+  do {                                                       \
+    if ((rhs_target) == 0) {                                 \
+      constexpr int RhsTarget = 0;                           \
+      { __VA_ARGS__ }                                        \
+    } else if ((rhs_target) == 1) {                          \
+      constexpr int RhsTarget = 1;                           \
+      { __VA_ARGS__ }                                        \
+    } else if ((rhs_target) == 2) {                          \
+      constexpr int RhsTarget = 2;                           \
+      { __VA_ARGS__ }                                        \
+    } else {                                                 \
+      LOG(INFO) << "Invalid rhs target: " << (rhs_target);   \
+    }                                                        \
+  } while (0)
+
+#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
+  do {                                                       \
+    if ((lhs_target) == 0) {                                 \
+      constexpr int LhsTarget = 0;                           \
+      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);        \
+    } else if ((lhs_target) == 1) {                          \
+      constexpr int LhsTarget = 1;                           \
+      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);        \
+    } else if ((lhs_target) == 2) {                          \
+      constexpr int LhsTarget = 2;                           \
+      SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__);        \
+    } else {                                                 \
+      LOG(INFO) << "Invalid lhs target: " << (lhs_target);   \
+    }                                                        \
+  } while (0)
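The two macros above exist because LhsTarget and RhsTarget are template parameters of the kernels: a runtime int has to be converted into a compile-time constant before it can be used as a template argument. A minimal standalone sketch of that dispatch pattern (illustrative names, not DGL code):

```cpp
#include <iostream>

// Stand-in for a kernel templated on the target, like SDDMMCsr/SDDMMCoo.
template <int Target>
void Kernel() { std::cout << "instantiated for target " << Target << "\n"; }

// Runtime-to-compile-time dispatch: each branch pins the value as constexpr,
// which is exactly what SWITCH_RHS/SWITCH_TARGET expand to.
void Dispatch(int target) {
  if (target == 0) {
    constexpr int T = 0; Kernel<T>();
  } else if (target == 1) {
    constexpr int T = 1; Kernel<T>();
  } else if (target == 2) {
    constexpr int T = 2; Kernel<T>();
  } else {
    std::cerr << "invalid target: " << target << "\n";
  }
}

int main() { Dispatch(2); }  // prints: instantiated for target 2
```

Nesting the lhs switch around the rhs switch makes all nine (LhsTarget, RhsTarget) combinations reachable from the two runtime integers, at the cost of instantiating nine kernel variants per op.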
 /*!
  * \brief CUDA implementation of g-SDDMM on Csr format.
@@ -46,11 +77,15 @@ template <int XPU, typename IdType, typename DType>
 void SDDMMCsr(const std::string& op,
               const BcastOff& bcast,
               const CSRMatrix& csr,
-              NDArray ufeat,
-              NDArray vfeat,
-              NDArray out) {
+              NDArray lhs,
+              NDArray rhs,
+              NDArray out,
+              int lhs_target,
+              int rhs_target) {
   SWITCH_OP(op, Op, {
-    cuda::SDDMMCsr<IdType, DType, Op>(bcast, csr, ufeat, vfeat, out);
+    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
+      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
+    });
   });
 }
@@ -61,39 +96,51 @@ template <int XPU, typename IdType, typename DType>
 void SDDMMCoo(const std::string& op,
               const BcastOff& bcast,
               const COOMatrix& coo,
-              NDArray ufeat,
-              NDArray vfeat,
-              NDArray out) {
+              NDArray lhs,
+              NDArray rhs,
+              NDArray out,
+              int lhs_target,
+              int rhs_target) {
   SWITCH_OP(op, Op, {
-    cuda::SDDMMCoo<IdType, DType, Op>(bcast, coo, ufeat, vfeat, out);
+    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
+      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
+    });
   });
 }
 
 template void SDDMMCsr<kDLGPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCsr<kDLGPU, int64_t, float>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCsr<kDLGPU, int32_t, double>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCsr<kDLGPU, int64_t, double>(
     const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 
 template void SDDMMCoo<kDLGPU, int32_t, float>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCoo<kDLGPU, int64_t, float>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCoo<kDLGPU, int32_t, double>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 template void SDDMMCoo<kDLGPU, int64_t, double>(
     const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
-    NDArray ufeat, NDArray vfeat, NDArray out);
+    NDArray lhs, NDArray rhs, NDArray out,
+    int lhs_target, int rhs_target);
 
 }  // namespace aten
 }  // namespace dgl
@@ -11,6 +11,7 @@
 #include "atomic.cuh"
 #include "functor.cuh"
 #include "./utils.h"
+#include "../selector.h"
 #include "../../runtime/cuda/cuda_common.h"
 
 namespace dgl {
@@ -28,13 +29,19 @@ namespace cuda {
  * in feature dimension.
  */
 template <typename Idx, typename DType, typename BinaryOp,
-          bool UseBcast = false, bool UseIdx = false>
+          bool UseBcast = false, bool UseIdx = false,
+          int LhsTarget = 0, int RhsTarget = 2>
 __global__ void SDDMMCooKernel(
-  const DType *ufeat, const DType *vfeat, DType *out,
-  const Idx *row, const Idx *col, const Idx* edge_map,
+  const DType* __restrict__ lhs,
+  const DType* __restrict__ rhs,
+  DType* __restrict__ out,
+  const Idx* __restrict__ row,
+  const Idx* __restrict__ col,
+  const Idx* __restrict__ edge_map,
   int64_t N, int64_t M, int64_t E, int64_t reduce_size,
-  const int64_t *ubcast_off, const int64_t *vbcast_off,
-  int64_t ufeat_len, int64_t vfeat_len, int64_t out_len) {
+  const int64_t* __restrict__ lhs_off,
+  const int64_t* __restrict__ rhs_off,
+  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
   // SDDMM with COO.
   Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
   const Idx stride_y = blockDim.y * gridDim.y;
@@ -43,15 +50,15 @@ __global__ void SDDMMCooKernel(
     const Idx dst = _ldg(col + ty);
     const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
     const DType* lhsoff = BinaryOp::use_lhs ?
-      (ufeat + src * ufeat_len): nullptr;
+      (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
     const DType* rhsoff = BinaryOp::use_rhs ?
-      (vfeat + dst * vfeat_len): nullptr;
+      (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
     DType* outoff = out + eid * out_len;
     int tx = blockIdx.x * blockDim.x + threadIdx.x;
     const int stride_x = blockDim.x * gridDim.x;
     while (tx < out_len) {
-      const Idx lhs_add = UseBcast ? ubcast_off[tx] : tx;
-      const Idx rhs_add = UseBcast ? vbcast_off[tx] : tx;
+      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
+      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
       DType val = BinaryOp::Call(
         lhsoff + lhs_add * reduce_size,
         rhsoff + rhs_add * reduce_size,
@@ -93,13 +100,19 @@ __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx
  * given edge on Csr format, it uses binary search (time complexity O(log N)).
  */
 template <typename Idx, typename DType, typename BinaryOp,
-          bool UseBcast = false, bool UseIdx = false>
+          bool UseBcast = false, bool UseIdx = false,
+          int LhsTarget = 0, int RhsTarget = 2>
 __global__ void SDDMMCsrKernel(
-  const DType *ufeat, const DType *vfeat, DType *out,
-  const Idx *indptr, const Idx *indices, const Idx* edge_map,
+  const DType* __restrict__ lhs,
+  const DType* __restrict__ rhs,
+  DType* __restrict__ out,
+  const Idx* __restrict__ indptr,
+  const Idx* __restrict__ indices,
+  const Idx* __restrict__ edge_map,
   int64_t N, int64_t M, int64_t E, int64_t reduce_size,
-  int64_t *ubcast_off, int64_t *vbcast_off,
-  int64_t ufeat_len, int64_t vfeat_len, int64_t out_len) {
+  const int64_t* __restrict__ lhs_off,
+  const int64_t* __restrict__ rhs_off,
+  int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
   // SDDMM with Csr.
   Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
   const Idx stride_y = blockDim.y * gridDim.y;
@@ -109,12 +122,14 @@ __global__ void SDDMMCsrKernel(
     const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
     int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
     const int64_t stride_x = blockDim.x * gridDim.x;
-    const DType* lhsoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
-    const DType* rhsoff = BinaryOp::use_rhs ? (vfeat + dst * vfeat_len): nullptr;
+    const DType* lhsoff = BinaryOp::use_lhs ?
+      (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
+    const DType* rhsoff = BinaryOp::use_rhs ?
+      (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
     DType* outoff = out + eid * out_len;
     while (tx < out_len) {
-      const Idx lhs_add = UseBcast ? ubcast_off[tx] : tx;
-      const Idx rhs_add = UseBcast ? vbcast_off[tx] : tx;
+      const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
+      const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
       DType val = BinaryOp::Call(
         lhsoff + lhs_add * reduce_size,
         rhsoff + rhs_add * reduce_size,
@@ -130,26 +145,27 @@ __global__ void SDDMMCsrKernel(
  * \brief CUDA implementation of g-SDDMM on Coo format.
  * \param bcast Broadcast information.
  * \param coo The Coo matrix.
- * \param ufeat The feature on source nodes.
- * \param vfeat The feature on destination nodes.
+ * \param lhs The left hand side operand feature.
+ * \param rhs The right hand side operand feature.
  * \param out The result feature on edges.
  */
-template <typename Idx, typename DType, typename Op>
+template <typename Idx, typename DType, typename Op,
+          int LhsTarget = 0, int RhsTarget = 2>
 void SDDMMCoo(
     const BcastOff& bcast,
     const COOMatrix& coo,
-    NDArray ufeat,
-    NDArray vfeat,
+    NDArray lhs,
+    NDArray rhs,
     NDArray out) {
   const Idx *row = coo.row.Ptr<Idx>();
   const Idx *col = coo.col.Ptr<Idx>();
   const Idx *edge_map = coo.data.Ptr<Idx>();
-  const DType *ufeat_data = ufeat.Ptr<DType>();
-  const DType *vfeat_data = vfeat.Ptr<DType>();
+  const DType *lhs_data = lhs.Ptr<DType>();
+  const DType *rhs_data = rhs.Ptr<DType>();
   DType *out_data = out.Ptr<DType>();
   auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
-  int64_t *ubcast_off = nullptr, *vbcast_off = nullptr;
+  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
   int64_t len = bcast.out_len,
           lhs_len = bcast.lhs_len,
           rhs_len = bcast.rhs_len;
@@ -165,13 +181,13 @@ void SDDMMCoo(
   const dim3 nthrs(ntx, nty);
   const bool use_idx = !IsNullArray(coo.data);
 
-  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, vbcast_off, {
-    SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx>
+  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
+    SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>
       <<<nblks, nthrs, 0, thr_entry->stream>>>(
-        ufeat_data, vfeat_data, out_data,
+        lhs_data, rhs_data, out_data,
        row, col, edge_map,
        coo.num_rows, coo.num_cols, nnz, reduce_dim,
-        ubcast_off, vbcast_off,
+        lhs_off, rhs_off,
        lhs_len, rhs_len, len
      );
  });
@@ -181,26 +197,28 @@ void SDDMMCoo(
 /*!
  * \brief CUDA implementation of g-SDDMM on Csr format.
  * \param bcast Broadcast information.
 * \param csr The Csr matrix.
- * \param ufeat The feature on source nodes.
- * \param vfeat The feature on destination nodes.
+ * \param lhs The left hand side operand feature.
+ * \param rhs The right hand side operand feature.
 * \param out The result feature on edges.
 */
-template <typename Idx, typename DType, typename Op>
+template <typename Idx, typename DType, typename Op,
+          int LhsTarget = 0, int RhsTarget = 2>
 void SDDMMCsr(
    const BcastOff& bcast,
    const CSRMatrix& csr,
-    NDArray ufeat,
-    NDArray vfeat,
+    NDArray lhs,
+    NDArray rhs,
    NDArray out) {
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const Idx *edge_map = csr.data.Ptr<Idx>();
-  const DType *ufeat_data = ufeat.Ptr<DType>();
-  const DType *vfeat_data = vfeat.Ptr<DType>();
+  const DType *lhs_data = lhs.Ptr<DType>();
+  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
-  int64_t *ubcast_off = nullptr, *vbcast_off = nullptr;
+  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
@@ -214,13 +232,13 @@ void SDDMMCsr(
  const dim3 nthrs(ntx, nty);
  const bool use_idx = !IsNullArray(csr.data);
 
-  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, vbcast_off, {
-    SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx>
+  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
+    SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>
      <<<nblks, nthrs, 0, thr_entry->stream>>>(
-        ufeat_data, vfeat_data, out_data,
+        lhs_data, rhs_data, out_data,
        indptr, indices, edge_map,
        N, M, E, reduce_dim,
-        ubcast_off, vbcast_off,
+        lhs_off, rhs_off,
        lhs_len, rhs_len, len
      );
  });
......
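The net effect of the kernel changes above: instead of hard-wiring lhs to source-node features and rhs to destination-node features, each operand's row is located through Selector, so either side can come from source nodes, edges, or destination nodes. A simplified host-side sketch of the addressing (the real kernels additionally apply broadcast offsets and reduce_size):

```cpp
#include <cstdint>

// Mirrors Selector<Target>::Call: 0 -> src, 1 -> edge, 2 -> dst.
template <int Target>
int64_t Pick(int64_t src, int64_t eid, int64_t dst) {
  return Target == 0 ? src : (Target == 1 ? eid : dst);
}

// For edge (src, eid, dst), the lhs operand row starts at
// lhs + Pick<LhsTarget>(...) * lhs_len, as in SDDMMCooKernel/SDDMMCsrKernel.
template <int LhsTarget, typename DType>
const DType* LhsRow(const DType* lhs, int64_t lhs_len,
                    int64_t src, int64_t eid, int64_t dst) {
  return lhs + Pick<LhsTarget>(src, eid, dst) * lhs_len;
}
```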
@@ -173,11 +173,11 @@ void CusparseCsrmm2(
   } else if ((op) == "div") {                                \
     typedef cuda::binary::Div<DType> Op;                     \
     { __VA_ARGS__ }                                          \
-  } else if ((op) == "copy_u") {                             \
-    typedef cuda::binary::CopyU<DType> Op;                   \
+  } else if ((op) == "copy_lhs") {                           \
+    typedef cuda::binary::CopyLhs<DType> Op;                 \
     { __VA_ARGS__ }                                          \
-  } else if ((op) == "copy_e") {                             \
-    typedef cuda::binary::CopyE<DType> Op;                   \
+  } else if ((op) == "copy_rhs") {                           \
+    typedef cuda::binary::CopyRhs<DType> Op;                 \
     { __VA_ARGS__ }                                          \
   } else {                                                   \
     LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
@@ -198,7 +198,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
              NDArray out,
              std::vector<NDArray> out_aux) {
   if (reduce == "sum") {
-    if (sizeof(IdType) == 4 && op == "copy_u") {
+    if (sizeof(IdType) == 4 && op == "copy_lhs") {
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
......
@@ -46,10 +46,17 @@ template <typename Idx, typename DType,
           typename BinaryOp, typename ReduceOp,
           bool UseBcast = false, bool UseIdx = false>
 __global__ void SpMMCooKernel(
-  const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
-  const Idx *row, const Idx *col, const Idx* edge_map,
+  const DType* __restrict__ ufeat,
+  const DType* __restrict__ efeat,
+  DType* __restrict__ out,
+  Idx* __restrict__ arg_u,
+  Idx* __restrict__ arg_e,
+  const Idx* __restrict__ row,
+  const Idx* __restrict__ col,
+  const Idx* __restrict__ edge_map,
   int64_t N, int64_t M, int64_t E,
-  int64_t *ubcast_off, int64_t *ebcast_off,
+  const int64_t* __restrict__ ubcast_off,
+  const int64_t* __restrict__ ebcast_off,
   int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
   // SPMM with COO.
   Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
@@ -87,10 +94,17 @@ template <typename Idx, typename DType,
           typename BinaryOp, typename ReduceOp,
           bool UseBcast = false, bool UseIdx = false>
 __global__ void ArgSpMMCooKernel(
-  const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
-  const Idx *row, const Idx *col, const Idx* edge_map,
+  const DType* __restrict__ ufeat,
+  const DType* __restrict__ efeat,
+  DType* __restrict__ out,
+  Idx* __restrict__ arg_u,
+  Idx* __restrict__ arg_e,
+  const Idx* __restrict__ row,
+  const Idx* __restrict__ col,
+  const Idx* __restrict__ edge_map,
   int64_t N, int64_t M, int64_t E,
-  int64_t *ubcast_off, int64_t *ebcast_off,
+  const int64_t* __restrict__ ubcast_off,
+  const int64_t* __restrict__ ebcast_off,
   int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
   // SPMM with COO arg max/min.
   Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
@@ -128,10 +142,17 @@ template <typename Idx, typename DType,
           typename BinaryOp, typename ReduceOp,
           bool UseBcast = false, bool UseIdx = false>
 __global__ void SpMMCsrKernel(
-  const DType *ufeat, const DType *efeat, DType *out, Idx *arg_u, Idx *arg_e,
-  const Idx *indptr, const Idx *indices, const Idx *edge_map,
+  const DType* __restrict__ ufeat,
+  const DType* __restrict__ efeat,
+  DType* __restrict__ out,
+  Idx* __restrict__ arg_u,
+  Idx* __restrict__ arg_e,
+  const Idx* __restrict__ indptr,
+  const Idx* __restrict__ indices,
+  const Idx* __restrict__ edge_map,
   int64_t num_rows, int64_t num_cols, int64_t nnz,
-  int64_t *ubcast_off, int64_t *ebcast_off,
+  const int64_t* __restrict__ ubcast_off,
+  const int64_t* __restrict__ ebcast_off,
   int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
   // SPMM with CSR.
   int ty = blockIdx.y * blockDim.y + threadIdx.y;
......
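These SpMM hunks change no logic; they only add `__restrict__` to every pointer parameter. The qualifier is a promise to the compiler that the buffers never alias each other, which lets it keep loaded values in registers instead of reloading them after every store, and (for const pointers) favors the read-only data path on the GPU. A minimal illustration of the contract (not DGL code):

```cpp
// With __restrict__, the compiler may assume out never aliases a or b, so
// a[i] and b[i] can be loaded once and reused; the caller must guarantee
// that the three buffers really do not overlap.
__global__ void axpy(const float* __restrict__ a,
                     const float* __restrict__ b,
                     float* __restrict__ out,
                     int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = 2.0f * a[i] + b[i];
}
```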
@@ -98,13 +98,15 @@ void SpMM(const std::string& op, const std::string& reduce,
 /*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
 void SDDMM(const std::string& op,
            HeteroGraphPtr graph,
-           NDArray ufeat,
-           NDArray efeat,
+           NDArray lhs,
+           NDArray rhs,
            NDArray out,
+           int lhs_target,
+           int rhs_target,
            SparseFormat format) {
   // TODO(zihao): format tuning
   format = SparseFormat::kCOO;
-  const auto& bcast = CalcBcastOff(op, ufeat, efeat);
+  const auto& bcast = CalcBcastOff(op, lhs, rhs);
 
   ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
     ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
@@ -112,11 +114,11 @@ void SDDMM(const std::string& op,
       if (format == SparseFormat::kCSR) {
         SDDMMCsr<XPU, IdType, DType>(
             op, bcast, graph->GetCSRMatrix(0),
-            ufeat, efeat, out);
+            lhs, rhs, out, lhs_target, rhs_target);
       } else if (format == SparseFormat::kCOO) {
         SDDMMCoo<XPU, IdType, DType>(
             op, bcast, graph->GetCOOMatrix(0),
-            ufeat, efeat, out);
+            lhs, rhs, out, lhs_target, rhs_target);
       } else {
         LOG(FATAL) << "SDDMM only supports CSR and COO formats";
       }
@@ -155,21 +157,23 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
     HeteroGraphRef graph = args[0];
     const std::string op = args[1];
-    NDArray U = args[2];
-    NDArray V = args[3];
-    NDArray E = args[4];
-    CheckCtx(graph->Context(), {U, V, E}, {"U_data", "V_data", "E_data"});
-    CheckContiguous({U, V, E}, {"U_data", "V_data", "E_data"});
+    NDArray lhs = args[2];
+    NDArray rhs = args[3];
+    NDArray out = args[4];
+    int lhs_target = args[5];
+    int rhs_target = args[6];
+    CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
+    CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
     CHECK_EQ(graph->NumEdgeTypes(), 1);
     auto pair = graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
     const dgl_type_t src_vtype = pair.first;
     const dgl_type_t dst_vtype = pair.second;
     CheckShape(
         {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
-        {0, 1, 2},
-        {U, E, V},
+        {lhs_target, rhs_target, 1},
+        {lhs, rhs, out},
         {"U_data", "E_data", "V_data"});
-    SDDMM(op, graph.sptr(), U, V, E, SparseFormat::kAny);
+    SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target, SparseFormat::kAny);
   });
......
@@ -47,9 +47,11 @@ template <int XPU, typename IdType, typename DType>
 void SDDMMCsr(const std::string& op,
               const BcastOff& bcast,
               const aten::CSRMatrix& csr,
-              NDArray ufeat,
-              NDArray efeat,
-              NDArray out);
+              NDArray lhs,
+              NDArray rhs,
+              NDArray out,
+              int lhs_target,
+              int rhs_target);
 
 /*!
  * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
@@ -58,9 +60,11 @@ template <int XPU, typename IdType, typename DType>
 void SDDMMCoo(const std::string& op,
               const BcastOff& bcast,
               const aten::COOMatrix& coo,
-              NDArray ufeat,
-              NDArray efeat,
-              NDArray out);
+              NDArray lhs,
+              NDArray rhs,
+              NDArray out,
+              int lhs_target,
+              int rhs_target);
 
 }  // namespace aten
 }  // namespace dgl
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/selector.h
* \brief Selector functions to select among src/edge/dst attributes.
*/
#ifndef DGL_ARRAY_SELECTOR_H_
#define DGL_ARRAY_SELECTOR_H_
#include <dmlc/logging.h>
namespace dgl {
namespace {
#ifdef __CUDACC__
#define DGLDEVICE __device__
#define DGLINLINE __forceinline__
#else
#define DGLDEVICE
#define DGLINLINE inline
#endif // __CUDACC__
} // namespace
/*!
* \brief Select among src/edge/dst feature/idx.
* \note the integer argument target specifies which target
* to choose, 0: src, 1: edge, 2: dst.
*/
template <int target>
struct Selector {
template <typename T>
static DGLDEVICE DGLINLINE T Call(T src, T edge, T dst) {
LOG(INFO) << "Target " << target << " not recognized.";
return src;
}
};
template <>
template <typename T>
DGLDEVICE DGLINLINE T Selector<0>::Call(T src, T edge, T dst) {
return src;
}
template <>
template <typename T>
DGLDEVICE DGLINLINE T Selector<1>::Call(T src, T edge, T dst) {
return edge;
}
template <>
template <typename T>
DGLDEVICE DGLINLINE T Selector<2>::Call(T src, T edge, T dst) {
return dst;
}
} // namespace dgl
#endif // DGL_ARRAY_SELECTOR_H_
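selector.h compiles for both host and device: under nvcc, DGLDEVICE/DGLINLINE expand to `__device__ __forceinline__`, otherwise to plain `inline`. Because the target is a template parameter, the selection costs nothing at runtime; the unspecialized primary template only fires for an unrecognized target. A small host-side usage sketch (assuming the header above is on the include path):

```cpp
#include "selector.h"

int main() {
  // target 0 -> src, 1 -> edge, 2 -> dst
  int src = 10, edge = 20, dst = 30;
  int a = dgl::Selector<0>::Call(src, edge, dst);  // 10
  int b = dgl::Selector<1>::Call(src, edge, dst);  // 20
  int c = dgl::Selector<2>::Call(src, edge, dst);  // 30
  return (a == 10 && b == 20 && c == 30) ? 0 : 1;
}
```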
@@ -15,7 +15,7 @@ namespace {
  * type, lhs array and rhs array.
  */
 bool UseBcast(const std::string& op, NDArray lhs, NDArray rhs) {
-  if (op == "copy_u" || op == "copy_e")
-    return false;  // broadcasting is not required for copy_u/copy_e
+  if (op == "copy_lhs" || op == "copy_rhs")
+    return false;  // broadcasting is not required for copy_lhs/copy_rhs
   if (lhs->ndim != rhs->ndim)
     return true;
@@ -77,7 +77,7 @@ BcastOff CalcBcastOff(const std::string& op, NDArray lhs, NDArray rhs) {
     }
     rst.out_len = out_len;
   } else {
-    rst.out_len = (op == "copy_e") ? rst.rhs_len : rst.lhs_len;
+    rst.out_len = (op == "copy_rhs") ? rst.rhs_len : rst.lhs_len;
     if (op == "dot") {
       rst.reduce_size = lhs->shape[lhs->ndim - 1];  // set reduce_size for dot.
       rst.out_len /= rst.reduce_size;  // out_len is divided by reduce_size in dot.
......
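For context on what the renamed offsets hold: CalcBcastOff precomputes, for each position of the flattened output feature, the matching position in the lhs and rhs features, so the kernels can simply do `lhs_add = UseBcast ? lhs_off[tx] : tx`. A sketch of what such offsets look like when a (3, 3) lhs is broadcast against a (1, 3) rhs (illustrative computation, not the DGL implementation):

```cpp
#include <cstdint>
#include <vector>

int main() {
  // lhs feature shape (3, 3), rhs feature shape (1, 3) -> output (3, 3).
  const int64_t rows = 3, cols = 3;
  std::vector<int64_t> lhs_off(rows * cols), rhs_off(rows * cols);
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      lhs_off[i * cols + j] = i * cols + j;  // full shape: identity map
      rhs_off[i * cols + j] = j;             // broadcast row collapses to 0
    }
  }
  // A kernel then combines lhs[lhs_off[tx]] with rhs[rhs_off[tx]] per tx.
}
```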
@@ -5,6 +5,7 @@
 */
 #include <dgl/array.h>
 #include <dgl/packed_func_ext.h>
+#include <dgl/immutable_graph.h>
 #include <dgl/runtime/container.h>
 
 #include "../c_api_common.h"
@@ -597,11 +598,9 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
     *rv = ret_list;
   });
 
 DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
-    GraphRef meta_graph = args[0];
-    HeteroGraphRef hg = args[1];
+    HeteroGraphRef hg = args[0];
     CHECK_GT(hg->NumEdgeTypes(), 0);
     auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
     std::vector<HeteroGraphPtr> rev_ugs;
@@ -614,7 +613,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
     }
     // node types are not changed
     const auto& num_nodes = g->NumVerticesPerType();
-    auto hgptr = CreateHeteroGraph(meta_graph.sptr(), rev_ugs, num_nodes);
-    *rv = HeteroGraphRef(hgptr);
+    const auto& meta_edges = hg->meta_graph()->Edges("eid");
+    // reverse the metagraph
+    const auto& rev_meta = ImmutableGraph::CreateFromCOO(hg->meta_graph()->NumVertices(),
+                                                         meta_edges.dst, meta_edges.src);
+    *rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes);
   });
 
 }  // namespace dgl
@@ -1295,7 +1295,7 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
   CSRPtr ret = in_csr_;
   if (!in_csr_->defined()) {
     if (out_csr_->defined()) {
-      const auto& newadj = aten::CSRTranspose(out_csr_->adj());
+      const auto& newadj = aten::CSRSort(aten::CSRTranspose(out_csr_->adj()));
 
       if (inplace)
         *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
......
@@ -1944,7 +1944,7 @@ def test_reverse(index_dtype):
         ('user', 'follows', 'user'): ([0, 1, 2, 4, 3, 1, 3], [1, 2, 3, 2, 0, 0, 1]),
     }, index_dtype=index_dtype)
     gidx = g._graph
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
 
     assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
     assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
@@ -1956,7 +1956,7 @@ def test_reverse(index_dtype):
     # force to start with 'csr'
     gidx = gidx.to_format('csr')
     gidx = gidx.to_format('any')
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
 
     assert gidx.format_in_use(0)[0] == 'csr'
     assert r_gidx.format_in_use(0)[0] == 'csc'
     assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
@@ -1969,7 +1969,7 @@ def test_reverse(index_dtype):
     # force to start with 'csc'
     gidx = gidx.to_format('csc')
     gidx = gidx.to_format('any')
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
 
     assert gidx.format_in_use(0)[0] == 'csc'
     assert r_gidx.format_in_use(0)[0] == 'csr'
     assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
@@ -1985,7 +1985,14 @@ def test_reverse(index_dtype):
         ('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
     }, index_dtype=index_dtype)
     gidx = g._graph
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
+
+    # metagraph
+    mg = gidx.metagraph
+    r_mg = r_gidx.metagraph
+    for etype in range(3):
+        assert mg.find_edge(etype) == r_mg.find_edge(etype)[::-1]
 
     # three node types and three edge types
     assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
     assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1)
@@ -2009,7 +2016,7 @@ def test_reverse(index_dtype):
     # force to start with 'csr'
     gidx = gidx.to_format('csr')
     gidx = gidx.to_format('any')
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
 
     # three node types and three edge types
     assert gidx.format_in_use(0)[0] == 'csr'
     assert r_gidx.format_in_use(0)[0] == 'csc'
@@ -2039,7 +2046,7 @@ def test_reverse(index_dtype):
     # force to start with 'csc'
     gidx = gidx.to_format('csc')
     gidx = gidx.to_format('any')
-    r_gidx = gidx.reverse(gidx.metagraph)
+    r_gidx = gidx.reverse()
 
     # three node types and three edge types
     assert gidx.format_in_use(0)[0] == 'csc'
     assert r_gidx.format_in_use(0)[0] == 'csr'
......
+from dgl.backend import gspmm, gsddmm
+from utils import parametrize_dtype
 import dgl
 import pytest
 import networkx as nx
 import backend as F
 import numpy as np
 
 np.random.seed(42)
 dgl.random.seed(42)
 
-def _unsqueeze_if_scalar(x):  # used in udf, to unsqueeze the feature if it's scalar
-    return x if F.ndim(x) > 1 else F.unsqueeze(x, -1)
-
-def _rand_operand_1(shp):
-    return F.tensor(np.random.rand(*shp))
-
-def _rand_operand_2(shp):  # for division op, the divisor should be greater than 1
-    return F.tensor(np.random.rand(*shp) + 1)
-
 udf_msg = {
     'add': lambda edges: {'m': edges.src['x'] + edges.data['w']},
     'sub': lambda edges: {'m': edges.src['x'] - edges.data['w']},
     'mul': lambda edges: {'m': edges.src['x'] * edges.data['w']},
     'div': lambda edges: {'m': edges.src['x'] / edges.data['w']},
-    'copy_u': lambda edges: {'m': edges.src['x']},
-    'copy_e': lambda edges: {'m': edges.data['w']}
+    'copy_lhs': lambda edges: {'m': edges.src['x']},
+    'copy_rhs': lambda edges: {'m': edges.data['w']}
 }
 
+def select(target, src, edge, dst):
+    if target == 'u':
+        return src
+    elif target == 'v':
+        return dst
+    elif target == 'e':
+        return edge
+
+def binary_op(msg, x, y):
+    if msg == 'add':
+        return x + y
+    elif msg == 'sub':
+        return x - y
+    elif msg == 'mul':
+        return x * y
+    elif msg == 'div':
+        return x / y
+    elif msg == 'dot':
+        return F.sum(x * y, -1, keepdims=True)
+    elif msg == 'copy_lhs':
+        return x
+    elif msg == 'copy_rhs':
+        return y
+
+def edge_func(lhs_target, rhs_target, msg):
+    def foo(edges):
+        return {
+            'm': binary_op(
+                msg,
+                select(lhs_target, edges.src, edges.data, edges.dst)['x'],
+                select(rhs_target, edges.src, edges.data, edges.dst)['y'])
+        }
+    return foo
+
 udf_apply_edges = {
-    'add': lambda edges: {'m': edges.src['x'] + edges.dst['y']},
-    'sub': lambda edges: {'m': edges.src['x'] - edges.dst['y']},
-    'mul': lambda edges: {'m': edges.src['x'] * edges.dst['y']},
-    'div': lambda edges: {'m': edges.src['x'] / edges.dst['y']},
-    'dot': lambda edges: {'m': F.sum(edges.src['x'] * edges.dst['y'], -1, keepdims=True)},
-    'copy_u': lambda edges: {'m': edges.src['x']},
+    lhs_target + '_' + msg + '_' + rhs_target: edge_func(lhs_target, rhs_target, msg)
+    for lhs_target in ['u', 'v', 'e']
+    for rhs_target in ['u', 'v', 'e']
+    for msg in ['add', 'sub', 'mul', 'div', 'dot', 'copy_lhs', 'copy_rhs']
 }
 
 udf_reduce = {
@@ -41,7 +67,7 @@ udf_reduce = {
 }
 
 graphs = [
-    dgl.rand_graph(30, 0),
+    # dgl.rand_graph(30, 0),
     dgl.rand_graph(100, 30),
     dgl.rand_graph(100, 3000),
     dgl.rand_bipartite(80, 160, 3000)
@@ -52,9 +78,9 @@ spmm_shapes = [
     ((5, 3, 1, 7), (1, 3, 7, 1)),
     ((1, 3, 1), (4, 1, 3)),
     ((3, 3), (1, 3)),
-    ((), (3,)),
-    ((3,), ()),
-    ((), ())
+    ((1,), (3,)),
+    ((3,), (1,)),
+    ((1,), (1,))
 ]
 
 sddmm_shapes = [
@@ -63,31 +89,59 @@ sddmm_shapes = [
     ((1, 3, 3), (4, 1, 3)),
     ((3, 3), (1, 3)),
     ((3,), (3,)),
-    ((), ())
+    ((1,), (1,))
 ]
 
 @pytest.mark.parametrize('g', graphs)
 @pytest.mark.parametrize('shp', spmm_shapes)
-@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'copy_u', 'copy_e'])
+@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'copy_lhs', 'copy_rhs'])
 @pytest.mark.parametrize('reducer', ['sum', 'min', 'max'])
-def test_spmm(g, shp, msg, reducer):
+@parametrize_dtype
+def test_spmm(g, shp, msg, reducer, index_dtype):
+    if dgl.backend.backend_name == 'tensorflow' and (reducer in ['min', 'max'] or index_dtype == 'int32'):
+        pytest.skip()  # tensorflow dlpack has problem writing into int32 arrays on GPU.
+    if index_dtype == 'int32':
+        g = g.int()
+    else:
+        g = g.long()
     print(g)
-    u = _rand_operand_1((g.number_of_src_nodes(),) + shp[0])
-    e = _rand_operand_2((g.number_of_edges(),) + shp[1])
-    print('u shape: {}, e shape: {}'.format(F.shape(u), F.shape(e)))
-    g.srcdata['x'] = _unsqueeze_if_scalar(u)
-    g.edata['w'] = _unsqueeze_if_scalar(e)
+    print(g.idtype)
+    hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1)
+    he = F.tensor(np.random.rand(*((g.number_of_edges(),) + shp[1])) + 1)
+    print('u shape: {}, e shape: {}'.format(F.shape(hu), F.shape(he)))
+    g.srcdata['x'] = F.attach_grad(F.clone(hu))
+    g.edata['w'] = F.attach_grad(F.clone(he))
     print('SpMM(message func: {}, reduce func: {})'.format(msg, reducer))
-    v = dgl.gspmm(g, msg, reducer, u, e)[0]
-    non_degree_indices = F.tensor(
-        np.nonzero(F.asnumpy(g.in_degrees()) != 0)[0])
-    v = F.gather_row(v, non_degree_indices)
-    g.update_all(udf_msg[msg], udf_reduce[reducer])
-    if 'v' in g.dstdata:
-        v1 = F.gather_row(g.dstdata['v'], non_degree_indices)
-        assert F.allclose(v, v1, rtol=1e-3, atol=1e-3)
-    print('passed')
+
+    u = F.attach_grad(F.clone(hu))
+    e = F.attach_grad(F.clone(he))
+    with F.record_grad():
+        v = gspmm(g, msg, reducer, u, e)
+        non_degree_indices = F.tensor(
+            np.nonzero(F.asnumpy(g.in_degrees()) != 0)[0])
+        v = F.gather_row(v, non_degree_indices)
+        if g.number_of_edges() > 0:
+            F.backward(F.reduce_sum(v))
+            if msg != 'copy_rhs':
+                grad_u = F.grad(u)
+            if msg != 'copy_lhs':
+                grad_e = F.grad(e)
+
+    with F.record_grad():
+        g.update_all(udf_msg[msg], udf_reduce[reducer])
+        if g.number_of_edges() > 0:
+            v1 = F.gather_row(g.dstdata['v'], non_degree_indices)
+            assert F.allclose(v, v1, rtol=1e-3, atol=1e-3)
+            print('forward passed')
+
+            F.backward(F.reduce_sum(v1))
+            if msg != 'copy_rhs':
+                assert F.allclose(F.grad(g.srcdata['x']), grad_u)
+            if msg != 'copy_lhs':
+                assert F.allclose(F.grad(g.edata['w']), grad_e)
+            print('backward passed')
+
     g.srcdata.pop('x')
     g.edata.pop('w')
@@ -95,26 +149,78 @@ def test_spmm(g, shp, msg, reducer):
 
 @pytest.mark.parametrize('g', graphs)
 @pytest.mark.parametrize('shp', sddmm_shapes)
-@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'dot', 'copy_u'])
-def test_sddmm(g, shp, msg):
+@pytest.mark.parametrize('lhs_target', ['u', 'v', 'e'])
+@pytest.mark.parametrize('rhs_target', ['u', 'v', 'e'])
+@pytest.mark.parametrize('msg', ['add', 'sub', 'mul', 'div', 'dot', 'copy_lhs', 'copy_rhs'])
+@parametrize_dtype
+def test_sddmm(g, shp, lhs_target, rhs_target, msg, index_dtype):
     if dgl.backend.backend_name == 'mxnet' and g.number_of_edges() == 0:
         pytest.skip()  # mxnet does not support zero shape tensor
+    if dgl.backend.backend_name == 'tensorflow' and index_dtype == 'int32':
+        pytest.skip()  # tensorflow dlpack has problem with int32 ndarray.
+    if index_dtype == 'int32':
+        g = g.int()
+    else:
+        g = g.long()
     print(g)
-    u = _rand_operand_1((g.number_of_src_nodes(),) + shp[0])
-    v = _rand_operand_2((g.number_of_dst_nodes(),) + shp[1])
-    print('u shape: {}, v shape: {}'.format(F.shape(u), F.shape(v)))
-    g.srcdata['x'] = _unsqueeze_if_scalar(u)
-    g.dstdata['y'] = _unsqueeze_if_scalar(v)
-
-    print('SDDMM(message func: {})'.format(msg))
-    e = dgl.gsddmm(g, msg, u, v)
-    g.apply_edges(udf_apply_edges[msg])
-    if 'm' in g.edata:
-        e1 = g.edata['m']
-        assert F.allclose(e, e1, rtol=1e-3, atol=1e-3)
-    print('passed')
-
-    g.srcdata.pop('x')
-    g.dstdata.pop('y')
+    print(g.idtype)
+
+    len_lhs = select(
+        lhs_target,
+        g.number_of_src_nodes(),
+        g.number_of_edges(),
+        g.number_of_dst_nodes())
+    lhs_shp = (len_lhs,) + shp[0]
+    len_rhs = select(
+        rhs_target,
+        g.number_of_src_nodes(),
+        g.number_of_edges(),
+        g.number_of_dst_nodes())
+    rhs_shp = (len_rhs,) + shp[1]
+
+    feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1)
+    feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1)
+    print('lhs shape: {}, rhs shape: {}'.format(F.shape(feat_lhs), F.shape(feat_rhs)))
+
+    lhs_frame = select(
+        lhs_target,
+        g.srcdata,
+        g.edata,
+        g.dstdata)
+    rhs_frame = select(
+        rhs_target,
+        g.srcdata,
+        g.edata,
+        g.dstdata)
+    lhs_frame['x'] = F.attach_grad(F.clone(feat_lhs))
+    rhs_frame['y'] = F.attach_grad(F.clone(feat_rhs))
+    msg_func = lhs_target + '_' + msg + '_' + rhs_target
+    print('SDDMM(message func: {})'.format(msg_func))
+
+    lhs = F.attach_grad(F.clone(feat_lhs))
+    rhs = F.attach_grad(F.clone(feat_rhs))
+    with F.record_grad():
+        e = gsddmm(g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target)
+        F.backward(F.reduce_sum(e))
+        grad_lhs = F.grad(lhs)
+        grad_rhs = F.grad(rhs)
+
+    with F.record_grad():
+        g.apply_edges(udf_apply_edges[msg_func])
+        if g.number_of_edges() > 0:
+            e1 = g.edata['m']
+            assert F.allclose(e, e1, rtol=1e-3, atol=1e-3)
+            print('forward passed')
+
+            F.backward(F.reduce_sum(e1))
+            if msg != 'copy_rhs':
+                assert F.allclose(F.grad(lhs_frame['x']), grad_lhs)
+            if msg != 'copy_lhs':
+                assert F.allclose(F.grad(rhs_frame['y']), grad_rhs)
+            print('backward passed')
+
+    lhs_frame.pop('x')
+    rhs_frame.pop('y')
     if 'm' in g.edata: g.edata.pop('m')
+
+if __name__ == '__main__':
+    test_spmm(graphs[0], spmm_shapes[5], 'copy_lhs', 'sum', 'int64')
@@ -278,7 +278,7 @@ def uniform_attention(g, shape):
 
 def test_edge_softmax():
     # Basic
-    g = dgl.DGLGraph(nx.path_graph(3))
+    g = dgl.graph(nx.path_graph(3))
     edata = F.ones((g.number_of_edges(), 1))
     a = nn.edge_softmax(g, edata)
     assert len(g.ndata) == 0
@@ -293,12 +293,7 @@ def test_edge_softmax():
     assert F.allclose(a, uniform_attention(g, a.shape))
 
     # Test both forward and backward with PyTorch built-in softmax.
-    g = dgl.DGLGraph()
-    g.add_nodes(30)
-    # build a complete graph
-    for i in range(30):
-        for j in range(30):
-            g.add_edge(i, j)
+    g = dgl.rand_graph(30, 900)
 
     score = F.randn((900, 1))
     score.requires_grad_()
@@ -317,6 +312,7 @@ def test_edge_softmax():
     assert F.allclose(score.grad, grad_score)
     print(score.grad[:10], grad_score[:10])
 
+    """
     # Test 2
     def generate_rand_graph(n, m=None, ctor=dgl.DGLGraph):
         if m is None:
@@ -340,14 +336,10 @@ def test_edge_softmax():
     assert len(g.dstdata) == 0
     assert len(g.edata) == 2
     assert F.allclose(a1.grad, a2.grad, rtol=1e-4, atol=1e-4)  # Follow tolerance in unittest backend
+    """
 
 def test_partial_edge_softmax():
-    g = dgl.DGLGraph()
-    g.add_nodes(30)
-    # build a complete graph
-    for i in range(30):
-        for j in range(30):
-            g.add_edge(i, j)
+    g = dgl.rand_graph(30, 900)
 
     score = F.randn((300, 1))
     score.requires_grad_()
@@ -446,7 +438,7 @@ def test_rgcn():
 
 def test_gat_conv():
     ctx = F.ctx()
-    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
+    g = dgl.rand_graph(100, 1000)
     gat = nn.GATConv(5, 2, 4)
     feat = F.randn((100, 5))
     gat = gat.to(ctx)
......