Unverified commit 658b2086, authored by Ramon Zhou and committed by GitHub
Browse files

[GraphBolt] Optimize hetero sampling on CPU (#7360)

parent 9090a879
...@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private: private:
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn, const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const; PickFn pick_fn) const;
...@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param offset The starting edge ID for the connected neighbors of the given * @param offset The starting edge ID for the connected neighbors of the given
* node. * node.
* @param num_neighbors The number of neighbors of this node. * @param num_neighbors The number of neighbors of this node.
* * @param num_picked_ptr The pointer of the tensor which stores the pick
* @return The pick number of the given node. * numbers.
*/ */
int64_t NumPick( template <typename PickedNumType>
void NumPick(
int64_t fanout, bool replace, int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset, const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors); int64_t num_neighbors, PickedNumType* num_picked_ptr);
int64_t TemporalNumPick( int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
...@@ -513,11 +521,13 @@ int64_t TemporalNumPick( ...@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset, const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors); int64_t offset, int64_t num_neighbors);
int64_t NumPickByEtype( template <typename PickedNumType>
const std::vector<int64_t>& fanouts, bool replace, void NumPickByEtype(
bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset, const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors); int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
int64_t TemporalNumPickByEtype( int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, torch::Tensor seed_timestamp, torch::Tensor csc_indices,
...@@ -610,16 +620,24 @@ int64_t TemporalPick( ...@@ -610,16 +620,24 @@ int64_t TemporalPick(
* probabilities associated with each neighboring edge of a node in the original * probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements * graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph. * equal to the number of edges in the graph.
* @param picked_data_ptr The destination address where the picked neighbors * @param picked_data_ptr The pointer of the tensor where the picked neighbors
* should be put. Enough memory space should be allocated in advance. * should be put. Enough memory space should be allocated in advance.
* @param seed_offset The offset(index) of the seed among the group of seeds
* which share the same node type.
* @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
* of the sampled subgraph.
* @param etype_id_to_num_picked_offset A vector storing the mappings from each
* etype_id to the offset of its pick numbers in the tensor.
*/ */
template <SamplerType S, typename PickedType> template <SamplerType S, typename PickedType>
int64_t PickByEtype( int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts, bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
bool replace, const torch::TensorOptions& options, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge, const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr); PickedType* picked_data_ptr, int64_t seed_offset,
PickedType* subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
template <typename PickedType> template <typename PickedType>
int64_t TemporalPickByEtype( int64_t TemporalPickByEtype(
......
This diff is collapsed.
...@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number(
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=ntypes, node_type_to_id=ntypes,
edge_type_to_id=etypes, edge_type_to_id=etypes,
) ).to(F.ctx())
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1]) nodes = {
"N0": torch.LongTensor([0]).to(F.ctx()),
"N1": torch.LongTensor([1]).to(F.ctx()),
}
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment