/*! * Copyright (c) 2018 by Contributors * @file graph/sampling/randomwalk_with_restart_cpu.cc * @brief DGL sampler - CPU implementation of metapath-based random walk with * restart with OpenMP */ #include #include #include #include #include #include "metapath_randomwalk.h" #include "randomwalks_cpu.h" #include "randomwalks_impl.h" namespace dgl { using namespace dgl::runtime; using namespace dgl::aten; namespace sampling { namespace impl { template std::pair RandomWalkWithRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, double restart_prob) { TerminatePredicate terminate = [restart_prob](IdxType *data, dgl_id_t curr, int64_t len) { return RandomEngine::ThreadLocal()->Uniform() < restart_prob; }; return MetapathBasedRandomWalk( hg, seeds, metapath, prob, terminate); } template std::pair RandomWalkWithRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, double restart_prob); template std::pair RandomWalkWithRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, double restart_prob); template std::pair RandomWalkWithStepwiseRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, FloatArray restart_prob) { std::pair result; ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, "restart probability", { DType *restart_prob_data = static_cast(restart_prob->data); TerminatePredicate terminate = [restart_prob_data](IdxType *data, dgl_id_t curr, int64_t len) { return RandomEngine::ThreadLocal()->Uniform() < restart_prob_data[len]; }; result = MetapathBasedRandomWalk( hg, seeds, metapath, prob, terminate); }); return result; } template std::pair RandomWalkWithStepwiseRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, FloatArray restart_prob); template std::pair RandomWalkWithStepwiseRestart( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath, const std::vector &prob, FloatArray restart_prob); }; // namespace impl }; // namespace sampling }; // namespace dgl