Unverified commit f0213d21, authored by Ramon Zhou and committed by GitHub
Browse files

[GraphBolt] Refactor sampling (#7367)

parent 6b140f28
...@@ -18,6 +18,7 @@ namespace graphbolt { ...@@ -18,6 +18,7 @@ namespace graphbolt {
namespace sampling { namespace sampling {
enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT }; enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };
enum TemporalOption { NOT_TEMPORAL, TEMPORAL };
constexpr bool is_labor(SamplerType S) { constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT; return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
...@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -413,18 +414,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm); SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm);
private: private:
template <typename NumPickFn, typename PickFn> template <TemporalOption Temporal, typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds, const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets, torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids, const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const; NumPickFn num_pick_fn, PickFn pick_fn) const;
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
/** @brief CSC format index pointer array. */ /** @brief CSC format index pointer array. */
torch::Tensor indptr_; torch::Tensor indptr_;
......
...@@ -492,7 +492,7 @@ auto GetTemporalPickFn( ...@@ -492,7 +492,7 @@ auto GetTemporalPickFn(
}; };
} }
template <typename NumPickFn, typename PickFn> template <TemporalOption Temporal, typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl( FusedCSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& seeds, const torch::Tensor& seeds,
...@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -512,7 +512,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> edge_offsets = torch::nullopt; torch::optional<torch::Tensor> edge_offsets = torch::nullopt;
bool with_seed_offsets = seed_offsets.has_value(); bool with_seed_offsets = seed_offsets.has_value();
bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1; bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1 &&
Temporal == TemporalOption::NOT_TEMPORAL;
// Get the number of edge types. If it's homo or if the size of fanouts is 1 // Get the number of edge types. If it's homo or if the size of fanouts is 1
// (hetero graph but sampled as a homo graph), set num_etypes as 1. // (hetero graph but sampled as a homo graph), set num_etypes as 1.
...@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -584,24 +585,31 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const auto offset = indptr_data[nid]; const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto seed_type_id = if constexpr (Temporal == TemporalOption::TEMPORAL) {
(hetero_with_seed_offsets) num_picked_neighbors_data_ptr[i + 1] =
? std::upper_bound( num_neighbors == 0
seed_offsets->begin(), seed_offsets->end(), ? 0
i) - : num_pick_fn(i, offset, num_neighbors);
seed_offsets->begin() - 1 } else {
: 0; const auto seed_type_id =
// `seed_index` indicates the index of the current (hetero_with_seed_offsets)
// seed within the group of seeds which have the same ? std::upper_bound(
// node type. seed_offsets->begin(),
const auto seed_index = seed_offsets->end(), i) -
(hetero_with_seed_offsets) seed_offsets->begin() - 1
? i - seed_offsets->at(seed_type_id) : 0;
: i; // `seed_index` indicates the index of the current
num_pick_fn( // seed within the group of seeds which have the same
offset, num_neighbors, // node type.
num_picked_neighbors_data_ptr + 1, seed_index, const auto seed_index =
etype_id_to_num_picked_offset); (hetero_with_seed_offsets)
? i - seed_offsets->at(seed_type_id)
: i;
num_pick_fn(
offset, num_neighbors,
num_picked_neighbors_data_ptr + 1, seed_index,
etype_id_to_num_picked_offset);
}
} }
}); });
...@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -684,16 +692,30 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
: i; : i;
// Step 4. Pick neighbors for each node. // Step 4. Pick neighbors for each node.
picked_number = pick_fn( if constexpr (Temporal == TemporalOption::TEMPORAL) {
offset, num_neighbors, picked_eids_data_ptr, picked_number = num_picked_neighbors_data_ptr[i + 1];
seed_index, subgraph_indptr_data_ptr, auto picked_offset = subgraph_indptr_data_ptr[i];
etype_id_to_num_picked_offset); if (picked_number > 0) {
if (!hetero_with_seed_offsets) { auto actual_picked_count = pick_fn(
TORCH_CHECK( i, offset, num_neighbors,
num_picked_neighbors_data_ptr[i + 1] == picked_eids_data_ptr + picked_offset);
picked_number, TORCH_CHECK(
"Actual picked count doesn't match the calculated " actual_picked_count == picked_number,
"pick number."); "Actual picked count doesn't match the calculated"
" pick number.");
}
} else {
picked_number = pick_fn(
offset, num_neighbors, picked_eids_data_ptr,
seed_index, subgraph_indptr_data_ptr,
etype_id_to_num_picked_offset);
if (!hetero_with_seed_offsets) {
TORCH_CHECK(
num_picked_neighbors_data_ptr[i + 1] ==
picked_number,
"Actual picked count doesn't match the calculated"
" pick number.");
}
} }
// Step 5. Calculate other attributes and return the // Step 5. Calculate other attributes and return the
...@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -779,141 +801,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets); subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets);
} }
/**
 * @brief Sample neighbors for each node in `nodes` using caller-supplied
 * counting and picking callbacks, and assemble the result as a
 * FusedSampledSubgraph.
 *
 * The work is done in five steps: (1) ask `num_pick_fn` how many neighbors
 * will be picked for every node, (2) take the cumulative sum of those counts
 * to obtain the indptr of the sampled subgraph, (3) allocate the output
 * tensors, (4) let `pick_fn` write the picked edge ids for every node, and
 * (5) gather `indices_` (and `type_per_edge_`, when present) at the picked
 * edge ids.
 *
 * @tparam NumPickFn Callable invoked as `num_pick_fn(i, offset,
 * num_neighbors)`; its result is stored as the pick count of the i-th node.
 * @tparam PickFn Callable invoked as `pick_fn(i, offset, num_neighbors,
 * out_ptr)`; its result is compared against the pick count from
 * `num_pick_fn`.
 *
 * @param nodes Tensor of node IDs to sample for; every ID is checked to be in
 * `[0, NumNodes())`.
 * @param return_eids If true, the picked edge ids are handed to the returned
 * subgraph; otherwise that field is `torch::nullopt`.
 * @param num_pick_fn See NumPickFn.
 * @param pick_fn See PickFn.
 *
 * @return An intrusive pointer to a FusedSampledSubgraph built from the
 * subgraph indptr, the gathered indices, `nodes`, `torch::nullopt`, the
 * optional picked edge ids and the optional gathered per-edge types (in that
 * constructor-argument order).
 */
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
const auto indptr_options = indptr_.options();
// One extra leading slot (set to 0 below) so that the cumulative sum of this
// tensor can be used directly as the indptr of the sampled subgraph.
torch::Tensor num_picked_neighbors_per_node =
torch::empty({num_nodes + 1}, indptr_options);
// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const int64_t grain_size = 64;
// Outputs are declared here and filled inside the dispatch lambdas below.
torch::Tensor picked_eids;
torch::Tensor subgraph_indptr;
torch::Tensor subgraph_indices;
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
// Outer dispatch selects the element type of `indptr_`; the inner dispatch
// selects the element type of `nodes`. Each dispatch exposes its type as
// `index_t`, hence the `using` aliases that keep both types nameable.
AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
using indptr_t = index_t;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
using nodes_t = index_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<indptr_t>();
// Slot 0 is the leading zero of the future indptr; the count of the
// i-th node is stored at slot i + 1.
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<nodes_t>();
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
// Reject IDs outside the graph before they are used to index
// `indptr_data`.
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs.");
// [offset, offset + num_neighbors) is the slice of this node's
// neighbors according to `indptr_`.
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
// `num_pick_fn` is not invoked for nodes without neighbors;
// their pick count is 0.
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0
? 0
: num_pick_fn(i, offset, num_neighbors);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors.
// The last indptr entry is the total number of picked edges.
const auto total_length =
subgraph_indptr.data_ptr<indptr_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices =
torch::empty({total_length}, indices_.options());
// Per-edge types are only produced when the graph stores them.
if (type_per_edge_.has_value()) {
subgraph_type_per_edge = torch::empty(
{total_length}, type_per_edge_.value().options());
}
// Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
// Count computed in Step 1 and the start of this node's output
// range taken from the subgraph indptr.
const auto picked_number =
num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i];
// Nodes with nothing to pick are skipped entirely; `pick_fn` is
// not called for them.
if (picked_number > 0) {
// `pick_fn` writes starting at this node's output offset and
// returns how many edge ids it wrote.
auto actual_picked_count = pick_fn(
i, offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
// The two callbacks must agree, otherwise the output ranges
// derived from the prefix sum would not match what was written.
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated "
"pick "
"number.");
// Step 5. Calculate other attributes and return the
// subgraph.
// Gather the neighbor IDs of the picked edges from `indices_`.
AT_DISPATCH_INDEX_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<index_t>();
auto indices_data_ptr =
indices_.data_ptr<index_t>();
// NOTE: this `i` shadows the outer node index; here it
// walks the output positions of the current node.
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
}
}));
// Gather the types of the picked edges when they are stored.
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
// NOTE: as above, this `i` shadows the outer node index.
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr
[picked_eids_data_ptr[i]];
}
}));
}
}
}
});
}));
}));
// The picked edge ids are only exposed to the caller on request; they are
// moved (not copied) because they are not used again in this function.
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<FusedSampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> seeds, torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets, torch::optional<std::vector<int64_t>> seed_offsets,
...@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -969,7 +856,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
indices_, indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)}, {random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()}; NumNodes()};
return SampleNeighborsImpl( return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids, seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn( GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, fanouts, replace, type_per_edge_, probs_or_mask,
...@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -990,7 +877,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
NumNodes()}; NumNodes()};
} }
}(); }();
return SampleNeighborsImpl( return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids, seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn( GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, fanouts, replace, type_per_edge_, probs_or_mask,
...@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -1001,7 +888,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} }
} else { } else {
SamplerArgs<SamplerType::NEIGHBOR> args; SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl( return SampleNeighborsImpl<TemporalOption::NOT_TEMPORAL>(
seeds.value(), seed_offsets, fanouts, return_eids, seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn( GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets), fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets),
...@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( ...@@ -1019,6 +906,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
bool return_eids, torch::optional<std::string> probs_name, bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name, torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const { torch::optional<std::string> edge_timestamp_attr_name) const {
torch::optional<std::vector<int64_t>> seed_offsets = torch::nullopt;
// 1. Get probs_or_mask. // 1. Get probs_or_mask.
auto probs_or_mask = this->EdgeAttribute(probs_name); auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value()) { if (probs_name.has_value()) {
...@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( ...@@ -1039,8 +927,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()); static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()}; SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return TemporalSampleNeighborsImpl( return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, return_eids, input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn( GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace, input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
...@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( ...@@ -1050,8 +938,8 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp, args)); edge_timestamp, args));
} else { } else {
SamplerArgs<SamplerType::NEIGHBOR> args; SamplerArgs<SamplerType::NEIGHBOR> args;
return TemporalSampleNeighborsImpl( return SampleNeighborsImpl<TemporalOption::TEMPORAL>(
input_nodes, return_eids, input_nodes, seed_offsets, fanouts, return_eids,
GetTemporalNumPickFn( GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace, input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment