Unverified Commit 2150fcaf authored by Israt Nisa, committed by GitHub
Browse files

[Feature] Added heterograph support to SDDMM_COO and clean up SpMM and SDDMM hetero kernels (#3449)



* Added SDDMMCOO_hetero support

* removed redundant CUDA kernels

* added benchmark for regression test

* fix

* fixed bug for single src node type
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent e053df79
import time
import dgl
import torch
import numpy as np
import dgl.function as fn
from .. import utils
@utils.benchmark('time', timeout=600)
@utils.parametrize('num_relations', [5, 50, 500])
@utils.parametrize('format', ['coo', 'csr'])
@utils.parametrize('feat_size', [8, 128, 512])
@utils.parametrize('reduce_type', ['u->e'])  # , 'e->u'])
def track_time(num_relations, format, feat_size, reduce_type):
    """Benchmark heterograph ``apply_edges`` (SDDMM) across relation counts,
    sparse formats and feature sizes.

    Parameters
    ----------
    num_relations : int
        Number of edge types; edge lists are cycled from three citation graphs.
    format : str
        Sparse format to benchmark ('coo' or 'csr').
    feat_size : int
        Node feature dimensionality.
    reduce_type : str
        Built-in message pattern to time (currently only 'u->e').

    Returns
    -------
    float
        Average seconds per ``apply_edges`` call over 10 timed runs.
    """
    device = utils.get_bench_device()
    dd = {}
    candidate_edges = [
        dgl.data.CoraGraphDataset(verbose=False)[0].edges(),
        dgl.data.PubmedGraphDataset(verbose=False)[0].edges(),
        dgl.data.CiteseerGraphDataset(verbose=False)[0].edges(),
    ]
    for i in range(num_relations):
        dd[('n1', 'e_{}'.format(i), 'n2')] = candidate_edges[
            i % len(candidate_edges)]
    graph = dgl.heterograph(dd)
    # Restrict the graph to the parametrized sparse format. Without this the
    # 'coo' and 'csr' variants would measure the exact same kernel path.
    graph = graph.formats([format])
    graph = graph.to(device)
    graph.nodes['n1'].data['h'] = torch.randn(
        (graph.num_nodes('n1'), feat_size), device=device)
    graph.nodes['n2'].data['h'] = torch.randn(
        (graph.num_nodes('n2'), feat_size), device=device)
    reduce_builtin_dict = {
        'u->e': fn.copy_u('h', 'x'),
        # 'e->u': fn.copy_e('h', 'x'),
    }
    # Dry runs to warm caches / JIT before timing.
    for _ in range(3):
        graph.apply_edges(reduce_builtin_dict[reduce_type])
    # Timed runs; report the mean.
    with utils.Timer() as t:
        for _ in range(10):
            graph.apply_edges(reduce_builtin_dict[reduce_type])
    return t.elapsed_secs / 10
...@@ -224,9 +224,13 @@ def data_dict_to_list(graph, data_dict, func, target): ...@@ -224,9 +224,13 @@ def data_dict_to_list(graph, data_dict, func, target):
else: else:
if target == 'u': if target == 'u':
lhs_list = [None] * graph._graph.number_of_ntypes() lhs_list = [None] * graph._graph.number_of_ntypes()
for srctype, _, _ in graph.canonical_etypes: if not isinstance(data_dict, dict):
src_id = graph.get_ntype_id(srctype) src_id, _ = graph._graph.metagraph.find_edge(0)
lhs_list[src_id] = data_dict[srctype] lhs_list[src_id] = data_dict
else:
for srctype, _, _ in graph.canonical_etypes:
src_id = graph.get_ntype_id(srctype)
lhs_list[src_id] = data_dict[srctype]
return lhs_list return lhs_list
else: # target == 'e': else: # target == 'e':
rhs_list = [None] * graph._graph.number_of_etypes() rhs_list = [None] * graph._graph.number_of_etypes()
......
...@@ -189,6 +189,34 @@ void SDDMMCoo(const std::string& op, ...@@ -189,6 +189,34 @@ void SDDMMCoo(const std::string& op,
}); });
} }
/*! \brief Generalized SDDMM on Coo format with Heterograph support. */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
COOMatrix coo = vec_coo[etype];
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
}
});
});
});
}
template void SDDMMCoo<kDLCPU, int32_t, 16>( template void SDDMMCoo<kDLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
...@@ -214,6 +242,49 @@ template void SDDMMCoo<kDLCPU, int64_t, 64>( ...@@ -214,6 +242,49 @@ template void SDDMMCoo<kDLCPU, int64_t, 64>(
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
// Explicit instantiations of the CPU heterograph SDDMM-on-COO kernel for every
// supported (index type, float width) combination: int32/int64 indices with
// 16/32/64-bit feature data.
template void SDDMMCooHetero<kDLCPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -106,20 +106,17 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -106,20 +106,17 @@ void SDDMMCsrHetero(const std::string& op,
int rhs_target, int rhs_target,
const std::vector<dgl_type_t>& lhs_eid, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) { const std::vector<dgl_type_t>& rhs_eid) {
// TODO(Israt): Resolve PR - https://github.com/dmlc/dgl/issues/2995
// to use maxstream > 1
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
SWITCH_BITS(bits, DType, { SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */ /* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) { for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype]; CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]]; NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]]; NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype]; NDArray out = vec_out[etype];
cuda::SDDMMCsrHetero<IdType, DType, Op, LhsTarget, RhsTarget>( cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out, thr_entry->stream); bcast, csr, lhs, rhs, out);
} }
}); });
}); });
...@@ -148,6 +145,41 @@ void SDDMMCoo(const std::string& op, ...@@ -148,6 +145,41 @@ void SDDMMCoo(const std::string& op,
}); });
} }
/*!
 * \brief CUDA implementation of g-SDDMM on heterograph using
 *        COO format.
 * \param op The binary operator, e.g. `add`, `sub`, `mul`, `div`.
 * \param bcast Broadcast information.
 * \param vec_coo Vector of COO matrices, one per relation (edge) type.
 * \param vec_lhs Vector of left-hand-side operand features, indexed by type id.
 * \param vec_rhs Vector of right-hand-side operand features, indexed by type id.
 * \param vec_out Vector of output features on edges, one per relation type.
 * \param lhs_target Location of the left operand (u, e or v).
 * \param rhs_target Location of the right operand (u, e or v).
 * \param lhs_eid Per-relation type ids selecting the lhs feature.
 * \param rhs_eid Per-relation type ids selecting the rhs feature.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
                    const BcastOff& bcast,
                    const std::vector<COOMatrix>& vec_coo,
                    const std::vector<NDArray>& vec_lhs,
                    const std::vector<NDArray>& vec_rhs,
                    std::vector<NDArray> vec_out,
                    int lhs_target,
                    int rhs_target,
                    const std::vector<dgl_type_t>& lhs_eid,
                    const std::vector<dgl_type_t>& rhs_eid) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
        /* Call SDDMM CUDA kernel for each relation type sequentially. */
        for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
          // Borrow by const reference to avoid copying the COO wrapper.
          const COOMatrix& coo = vec_coo[etype];
          NDArray lhs = vec_lhs[lhs_eid[etype]];
          NDArray rhs = vec_rhs[rhs_eid[etype]];
          NDArray out = vec_out[etype];
          cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
              bcast, coo, lhs, rhs, out);
        }
      });
    });
  });
}
template void SDDMMCsr<kDLGPU, int32_t, 16>( template void SDDMMCsr<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
...@@ -216,7 +248,6 @@ template void SDDMMCsrHetero<kDLGPU, int64_t, 64>( ...@@ -216,7 +248,6 @@ template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCoo<kDLGPU, int32_t, 16>( template void SDDMMCoo<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
...@@ -242,5 +273,48 @@ template void SDDMMCoo<kDLGPU, int64_t, 64>( ...@@ -242,5 +273,48 @@ template void SDDMMCoo<kDLGPU, int64_t, 64>(
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
// Explicit instantiations of the GPU heterograph SDDMM-on-COO kernel for every
// supported (index type, float width) combination: int32/int64 indices with
// 16/32/64-bit feature data.
template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
    const std::string& op, const BcastOff& bcast,
    const std::vector<COOMatrix>& vec_coo,
    const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
    std::vector<NDArray> out, int lhs_target, int rhs_target,
    const std::vector<dgl_type_t>& in_eid,
    const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -310,57 +310,6 @@ void SDDMMCsr( ...@@ -310,57 +310,6 @@ void SDDMMCsr(
}); });
} }
/*!
 * \brief CUDA implementation of g-SDDMM on heterograph using Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param lhs The left hand side operand feature.
 * \param rhs The right hand size operand feature.
 * \param out The result feature on edges.
 * \param stream cudaStream id.
 */
template <typename Idx, typename DType, typename Op,
          int LhsTarget = 0, int RhsTarget = 2>
void SDDMMCsrHetero(
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray lhs,
    NDArray rhs,
    NDArray out,
    cudaStream_t strm_id) {
  // Raw device pointers into the CSR structure and operand/output features.
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const Idx *edge_map = csr.data.Ptr<Idx>();
  const DType *lhs_data = lhs.Ptr<DType>();
  const DType *rhs_data = rhs.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
  // Broadcast offset tables; filled in by BCAST_IDX_CTX_SWITCH below when
  // broadcasting is in effect, otherwise left null.
  int64_t *lhs_off = nullptr, *rhs_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  int64_t reduce_dim = bcast.reduce_size;
  // Launch configuration: x-dimension covers feature positions (len),
  // y-dimension covers the E edges of this relation.
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((E + nty - 1) / nty);
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  // A null csr.data means edges are already in canonical order and no
  // per-edge index remap is needed.
  const bool use_idx = !IsNullArray(csr.data);
  BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
    CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
        nblks, nthrs, 0, strm_id,
        lhs_data, rhs_data, out_data,
        indptr, indices, edge_map,
        N, M, E, reduce_dim,
        lhs_off, rhs_off,
        lhs_len, rhs_len, len);
  });
}
} // namespace cuda } // namespace cuda
} // namespace aten } // namespace aten
......
...@@ -521,7 +521,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -521,7 +521,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const std::vector<dgl_type_t>& out_ntids) { // output node type id const std::vector<dgl_type_t>& out_ntids) { // output node type id
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0]; bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs"; bool use_efeat = op != "copy_lhs";
// TODO(Israt): Resolve PR-https://github.com/dmlc/dgl/issues/2995 and use multistream
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx); auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
SWITCH_BITS(bits, DType, { SWITCH_BITS(bits, DType, {
std::vector<DType*> trans_out(vec_out.size(), NULL); std::vector<DType*> trans_out(vec_out.size(), NULL);
...@@ -603,34 +602,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -603,34 +602,27 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype]; NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >( cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], bcast, csr, ufeat, efeat, vec_out[dst_id], NullArray(), NullArray());
NullArray(), NullArray(), thr_entry->stream);
}); });
} }
} else if (reduce == "max") { } else if (reduce == "max") {
// SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id]; NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype]; NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >( cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
out_aux[0], out_aux[1], thr_entry->stream);
}); });
// });
} else if (reduce == "min") { } else if (reduce == "min") {
// SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id]; NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype]; NullArray() : vec_efeat[etype];
cuda::SpMMCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >( cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, vec_out[dst_id], bcast, csr, ufeat, efeat, vec_out[dst_id], out_aux[0], out_aux[1]);
out_aux[0], out_aux[1], thr_entry->stream);
// });
}); });
} else { } else {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
......
...@@ -317,64 +317,6 @@ void SpMMCsr( ...@@ -317,64 +317,6 @@ void SpMMCsr(
} }
/*!
 * \brief CUDA implementation of g-SpMM on Csr format.
 * \param bcast Broadcast information.
 * \param csr The Csr matrix.
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
 * \param argu Arg-Min/Max on source nodes, which refers the source node indices
 *        correspond to the minimum/maximum values of reduction result on
 *        destination nodes. It's useful in computing gradients of Min/Max reducer.
 * \param arge Arg-Min/Max on edges. which refers the source node indices
 *        correspond to the minimum/maximum values of reduction result on
 *        destination nodes. It's useful in computing gradients of Min/Max reducer.
 * \param stream cudaStream id.
 */
template <typename Idx, typename DType,
          typename BinaryOp, typename ReduceOp>
void SpMMCsrHetero(
    const BcastOff& bcast,
    const CSRMatrix& csr,
    NDArray ufeat, NDArray efeat,
    NDArray out, NDArray argu, NDArray arge,
    cudaStream_t strm_id) {
  // Raw device pointers into the CSR structure, operands and outputs.
  const Idx *indptr = csr.indptr.Ptr<Idx>();
  const Idx *indices = csr.indices.Ptr<Idx>();
  const Idx *edge_map = csr.data.Ptr<Idx>();
  const DType *ufeat_data = ufeat.Ptr<DType>();
  const DType *efeat_data = efeat.Ptr<DType>();
  DType *out_data = out.Ptr<DType>();
  // argu/arge may be null arrays for reducers (e.g. sum) that track no argmin/argmax.
  Idx* argu_data = argu.Ptr<Idx>();
  Idx* arge_data = arge.Ptr<Idx>();
  // Broadcast offset tables, filled by BCAST_IDX_CTX_SWITCH when broadcasting.
  int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
  int64_t len = bcast.out_len,
          lhs_len = bcast.lhs_len,
          rhs_len = bcast.rhs_len;
  // Launch configuration: x-dimension covers feature positions (len),
  // y-dimension covers destination rows of the CSR matrix.
  const int ntx = FindNumThreads(len);
  const int nty = CUDA_MAX_NUM_THREADS / ntx;
  const int nbx = (len + ntx - 1) / ntx;
  const int nby = FindNumBlocks<'y'>((csr.num_rows + nty - 1) / nty);
  //LOG(INFO) << "nblks=(" << nbx << ", " << nby << ") nthrs=(" << ntx << ", " << nty << ")";
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  // Null csr.data means edge ids are already canonical; skip the remap.
  const bool use_idx = !IsNullArray(csr.data);
  BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
    // NOTE(review): no ';' after CUDA_KERNEL_CALL — presumably the macro
    // expands to a complete statement/block; verify against its definition.
    CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
        nblks, nthrs, 0, strm_id,
        ufeat_data, efeat_data, out_data, argu_data, arge_data,
        indptr, indices, edge_map,
        csr.num_rows, csr.num_cols,
        ubcast_off, ebcast_off,
        lhs_len, rhs_len, len)
  });
}
} // namespace cuda } // namespace cuda
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -152,14 +152,11 @@ void SDDMMHetero(const std::string& op, ...@@ -152,14 +152,11 @@ void SDDMMHetero(const std::string& op,
std::vector<NDArray> out, std::vector<NDArray> out,
int lhs_target, int lhs_target,
int rhs_target) { int rhs_target) {
// TODO(Israt): change it to COO_CODE SparseFormat format = graph->SelectFormat(0, COO_CODE);
SparseFormat format = graph->SelectFormat(0, CSR_CODE);
std::vector<CSRMatrix> vec_csr;
std::vector<dgl_type_t> lhs_eid; std::vector<dgl_type_t> lhs_eid;
std::vector<dgl_type_t> rhs_eid; std::vector<dgl_type_t> rhs_eid;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype)); lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype)); rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
} }
...@@ -169,14 +166,25 @@ void SDDMMHetero(const std::string& op, ...@@ -169,14 +166,25 @@ void SDDMMHetero(const std::string& op,
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", { ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSR) { if (format == SparseFormat::kCSR) {
std::vector<CSRMatrix> vec_csr;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
}
SDDMMCsrHetero<XPU, IdType, bits>( SDDMMCsrHetero<XPU, IdType, bits>(
op, bcast, vec_csr, op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target, lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid); lhs_eid, rhs_eid);
} else if (format == SparseFormat::kCOO) {
std::vector<COOMatrix> vec_coo;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_coo.push_back(graph->GetCOOMatrix(etype));
}
SDDMMCooHetero<XPU, IdType, bits>(
op, bcast, vec_coo,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else { } else {
// TODO(Israt): Add support for COO format LOG(FATAL) << "SDDMM only supports CSR and COO formats";
LOG(FATAL) << "SDDMM only supports CSC format for graphs with number "
<< "of relation types > 1";
} }
}); });
}); });
......
...@@ -96,6 +96,22 @@ void SDDMMCoo(const std::string& op, ...@@ -96,6 +96,22 @@ void SDDMMCoo(const std::string& op,
int lhs_target, int lhs_target,
int rhs_target); int rhs_target);
/*!
 * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo
 *        format with heterograph support.
 * \param op The binary operator, e.g. `add`, `sub`, `mul`, `div`.
 * \param bcast Broadcast information.
 * \param vec_coo Vector of COO matrices, one per relation (edge) type.
 * \param vec_lhs Vector of left-hand-side operand features, indexed by type id.
 * \param vec_rhs Vector of right-hand-side operand features, indexed by type id.
 * \param vec_out Vector of output features on edges, one per relation type.
 * \param lhs_target Location of the left operand (u, e or v).
 * \param rhs_target Location of the right operand (u, e or v).
 * \param lhs_eid Per-relation type ids selecting the lhs feature.
 * \param rhs_eid Per-relation type ids selecting the rhs feature.
 */
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
                    const BcastOff& bcast,
                    const std::vector<COOMatrix>& vec_coo,
                    const std::vector<NDArray>& vec_lhs,
                    const std::vector<NDArray>& vec_rhs,
                    std::vector<NDArray> vec_out,
                    int lhs_target,
                    int rhs_target,
                    const std::vector<dgl_type_t>& lhs_eid,
                    const std::vector<dgl_type_t>& rhs_eid);
/*! /*!
* \brief Segment reduce. * \brief Segment reduce.
*/ */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment