"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "3430fd6849ce9a80a2dd5b72fbaf38357a4a7060"
Unverified Commit b53f9365 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] SampleNeighbors code refactor: add pick_fn (#6086)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 59a0aa97
...@@ -222,12 +222,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -222,12 +222,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
const std::string& shared_memory_name); const std::string& shared_memory_name);
private: private:
template <SamplerType S> template <typename PickFn>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl( c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, bool return_eids, PickFn pick_fn) const;
bool replace, bool return_eids,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<S> args) const;
/** /**
* @brief Build a CSCSamplingGraph from shared memory tensors. * @brief Build a CSCSamplingGraph from shared memory tensors.
...@@ -329,6 +326,20 @@ torch::Tensor Pick( ...@@ -329,6 +326,20 @@ torch::Tensor Pick(
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args); const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
template <>
torch::Tensor Pick<SamplerType::NEIGHBOR>(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args);
template <>
torch::Tensor Pick<SamplerType::LABOR>(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args);
/** /**
* @brief Picks a specified number of neighbors for a node per edge type, * @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors. * starting from the given offset and having the specified number of neighbors.
......
...@@ -131,16 +131,53 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -131,16 +131,53 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
: torch::nullopt); : torch::nullopt);
} }
/**
* @brief Get a lambda function which contains the sampling process.
*
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
* @param type_per_edge A tensor representing the type of each edge, if
* present.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param args Contains sampling algorithm specific arguments.
*
* @return A lambda function: (int64_t offset, int64_t num_neighbors) ->
* torch::Tensor, which takes offset and num_neighbors as params and returns a
* tensor of picked neighbors.
*/
template <SamplerType S> template <SamplerType S>
auto GetPickFn(
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
int64_t offset, int64_t num_neighbors) {
// If fanouts.size() > 1, perform sampling for each edge type of each node;
// otherwise just sample once for each node with no regard of edge types.
if (fanouts.size() > 1) {
return PickByEtype(
offset, num_neighbors, fanouts, replace, options,
type_per_edge.value(), probs_or_mask, args);
} else {
return Pick(
offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
args);
}
};
}
template <typename PickFn>
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, bool return_eids, PickFn pick_fn) const {
bool replace, bool return_eids,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<S> args) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
const int64_t num_threads = torch::get_num_threads(); const int64_t num_threads = torch::get_num_threads();
std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads); std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
torch::Tensor num_picked_neighbors_per_node = torch::Tensor num_picked_neighbors_per_node =
...@@ -178,15 +215,10 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl( ...@@ -178,15 +215,10 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
continue; continue;
} }
if (consider_etype) { picked_neighbors_cur_thread[i - begin] =
picked_neighbors_cur_thread[i - begin] = PickByEtype( pick_fn(offset, num_neighbors);
offset, num_neighbors, fanouts, replace, indptr_options,
type_per_edge_.value(), probs_or_mask, args); // This number should be the same as the result of num_pick_fn.
} else {
picked_neighbors_cur_thread[i - begin] = Pick(
offset, num_neighbors, fanouts[0], replace,
indptr_options, probs_or_mask, args);
}
num_picked_neighbors_per_node[i + 1] = num_picked_neighbors_per_node[i + 1] =
picked_neighbors_cur_thread[i - begin].size(0); picked_neighbors_cur_thread[i - begin].size(0);
} }
...@@ -227,16 +259,23 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -227,16 +259,23 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
probs_or_mask = probs_or_mask.value().to(torch::kFloat32); probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
} }
} }
if (layer) { if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()); static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()}; SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes, fanouts, replace, return_eids, probs_or_mask, args); nodes, return_eids,
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
} else { } else {
SamplerArgs<SamplerType::NEIGHBOR> args; SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes, fanouts, replace, return_eids, probs_or_mask, args); nodes, return_eids,
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment