Unverified Commit 658b2086 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Optimize hetero sampling on CPU (#7360)

parent 9090a879
......@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
......@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param offset The starting edge ID for the connected neighbors of the given
* node.
* @param num_neighbors The number of neighbors of this node.
*
* @return The pick number of the given node.
* @param num_picked_ptr The pointer of the tensor which stores the pick
* numbers.
*/
int64_t NumPick(
template <typename PickedNumType>
void NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr);
int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
......@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);
int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
template <typename PickedNumType>
void NumPickByEtype(
bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
......@@ -610,16 +620,24 @@ int64_t TemporalPick(
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param picked_data_ptr The destination address where the picked neighbors
* @param picked_data_ptr The pointer of the tensor where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
* @param seed_offset The offset(index) of the seed among the group of seeds
* which share the same node type.
* @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
* of the sampled subgraph.
* @param etype_id_to_num_picked_offset A vector storing the mappings from each
* etype_id to the offset of its pick numbers in the tensor.
*/
template <SamplerType S, typename PickedType>
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
PickedType* picked_data_ptr, int64_t seed_offset,
PickedType* subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
template <typename PickedType>
int64_t TemporalPickByEtype(
......
This diff is collapsed.
......@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number(
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
).to(F.ctx())
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
nodes = {
"N0": torch.LongTensor([0]).to(F.ctx()),
"N1": torch.LongTensor([1]).to(F.ctx()),
}
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment