Unverified Commit ddc92f8d authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

speed up random walks (#3158)


Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent d1cc0969
...@@ -57,6 +57,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -57,6 +57,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
dgl_id_t curr, dgl_id_t curr,
int64_t len, int64_t len,
const std::vector<CSRMatrix> &edges_by_type, const std::vector<CSRMatrix> &edges_by_type,
const std::vector<bool> &csr_has_data,
const IdxType *metapath_data, const IdxType *metapath_data,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
...@@ -70,7 +71,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -70,7 +71,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
const CSRMatrix &csr = edges_by_type[etype]; const CSRMatrix &csr = edges_by_type[etype];
const IdxType *offsets = csr.indptr.Ptr<IdxType>(); const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>(); const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr; const IdxType *all_eids = csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr]; const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr; const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
...@@ -124,6 +125,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -124,6 +125,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
dgl_id_t curr, dgl_id_t curr,
int64_t len, int64_t len,
const std::vector<CSRMatrix> &edges_by_type, const std::vector<CSRMatrix> &edges_by_type,
const std::vector<bool> &csr_has_data,
const IdxType *metapath_data, const IdxType *metapath_data,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
...@@ -137,7 +139,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -137,7 +139,7 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
const CSRMatrix &csr = edges_by_type[etype]; const CSRMatrix &csr = edges_by_type[etype];
const IdxType *offsets = csr.indptr.Ptr<IdxType>(); const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>(); const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr; const IdxType *all_eids = csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr]; const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr; const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
...@@ -179,9 +181,14 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk( ...@@ -179,9 +181,14 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
// This forces the heterograph to materialize all OutCSR's before the OpenMP loop; // This forces the heterograph to materialize all OutCSR's before the OpenMP loop;
// otherwise data races will happen. // otherwise data races will happen.
// TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless told otherwise? // TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless told otherwise?
std::vector<CSRMatrix> edges_by_type; int64_t num_etypes = hg->NumEdgeTypes();
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) std::vector<CSRMatrix> edges_by_type(num_etypes);
edges_by_type.push_back(hg->GetCSRMatrix(etype)); std::vector<bool> csr_has_data(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) {
const CSRMatrix &csr = hg->GetCSRMatrix(etype);
edges_by_type[etype] = csr;
csr_has_data[etype] = CSRHasData(csr);
}
// Hoist the check for Uniform vs Non uniform edge distribution // Hoist the check for Uniform vs Non uniform edge distribution
// to avoid putting it on the hot path // to avoid putting it on the hot path
...@@ -194,18 +201,18 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk( ...@@ -194,18 +201,18 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
} }
if (!isUniform) { if (!isUniform) {
StepFunc<IdxType> step = StepFunc<IdxType> step =
[&edges_by_type, metapath_data, &prob, terminate] [&edges_by_type, &csr_has_data, metapath_data, &prob, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) { (IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStep<XPU, IdxType>( return MetapathRandomWalkStep<XPU, IdxType>(
data, curr, len, edges_by_type, metapath_data, prob, terminate); data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
}; };
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step); return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
} else { } else {
StepFunc<IdxType> step = StepFunc<IdxType> step =
[&edges_by_type, metapath_data, &prob, terminate] [&edges_by_type, &csr_has_data, metapath_data, &prob, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) { (IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStepUniform<XPU, IdxType>( return MetapathRandomWalkStepUniform<XPU, IdxType>(
data, curr, len, edges_by_type, metapath_data, prob, terminate); data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate);
}; };
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step); return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
} }
......
...@@ -68,11 +68,11 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, ...@@ -68,11 +68,11 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u,
template <DLDeviceType XPU, typename IdxType> template <DLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep( std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q, IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q,
int64_t len, const CSRMatrix &csr, const FloatArray &probs, int64_t len, const CSRMatrix &csr, bool csr_has_data, const FloatArray &probs,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
const IdxType *offsets = csr.indptr.Ptr<IdxType>(); const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>(); const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = CSRHasData(csr) ? csr.data.Ptr<IdxType>() : nullptr; const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr;
const IdxType *succ = all_succ + offsets[curr]; const IdxType *succ = all_succ + offsets[curr];
const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr; const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
...@@ -153,13 +153,14 @@ std::pair<IdArray, IdArray> Node2vecRandomWalk( ...@@ -153,13 +153,14 @@ std::pair<IdArray, IdArray> Node2vecRandomWalk(
const int64_t max_num_steps, const FloatArray &prob, const int64_t max_num_steps, const FloatArray &prob,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
const CSRMatrix &edges = g->GetCSRMatrix(0); // homogeneous graph. const CSRMatrix &edges = g->GetCSRMatrix(0); // homogeneous graph.
bool csr_has_data = CSRHasData(edges);
StepFunc<IdxType> step = StepFunc<IdxType> step =
[&edges, &prob, p, q, terminate] [&edges, csr_has_data, &prob, p, q, terminate]
(IdxType *data, dgl_id_t curr, int64_t len) { (IdxType *data, dgl_id_t curr, int64_t len) {
dgl_id_t pre = (len != 0) ? data[len - 1] : curr; dgl_id_t pre = (len != 0) ? data[len - 1] : curr;
return Node2vecRandomWalkStep<XPU, IdxType>(data, curr, pre, p, q, len, return Node2vecRandomWalkStep<XPU, IdxType>(data, curr, pre, p, q, len,
edges, prob, terminate); edges, csr_has_data, prob, terminate);
}; };
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step); return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment