Unverified Commit 658b2086 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Optimize hetero sampling on CPU (#7360)

parent 9090a879
......@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& seeds,
torch::optional<std::vector<int64_t>>& seed_offsets,
const std::vector<int64_t>& fanouts, bool return_eids,
NumPickFn num_pick_fn, PickFn pick_fn) const;
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const;
......@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param offset The starting edge ID for the connected neighbors of the given
* node.
* @param num_neighbors The number of neighbors of this node.
*
* @return The pick number of the given node.
* @param num_picked_ptr The pointer of the tensor which stores the pick
* numbers.
*/
int64_t NumPick(
template <typename PickedNumType>
void NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr);
int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
......@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);
int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
template <typename PickedNumType>
void NumPickByEtype(
bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);
int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
......@@ -610,16 +620,24 @@ int64_t TemporalPick(
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param picked_data_ptr The destination address where the picked neighbors
* @param picked_data_ptr The pointer of the tensor where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
* @param seed_offset The offset(index) of the seed among the group of seeds
* which share the same node type.
* @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
* of the sampled subgraph.
* @param etype_id_to_num_picked_offset A vector storing the mappings from each
* etype_id to the offset of its pick numbers in the tensor.
*/
template <SamplerType S, typename PickedType>
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
PickedType* picked_data_ptr, int64_t seed_offset,
PickedType* subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset);
template <typename PickedType>
int64_t TemporalPickByEtype(
......
......@@ -18,6 +18,7 @@
#include <type_traits>
#include <vector>
#include "./expand_indptr.h"
#include "./macro.h"
#include "./random.h"
#include "./shared_memory_helper.h"
......@@ -355,17 +356,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
auto GetNumPickFn(
const std::vector<int64_t>& fanouts, bool replace,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask) {
const torch::optional<torch::Tensor>& probs_or_mask,
bool with_seed_offsets) {
// If fanouts.size() > 1, returns the total number of all edge types of the
// given node.
return [&fanouts, replace, &probs_or_mask, &type_per_edge](
int64_t seed_offset, int64_t offset, int64_t num_neighbors) {
return [&fanouts, replace, &probs_or_mask, &type_per_edge, with_seed_offsets](
int64_t offset, int64_t num_neighbors, auto num_picked_ptr,
int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset) {
if (fanouts.size() > 1) {
return NumPickByEtype(
fanouts, replace, type_per_edge.value(), probs_or_mask, offset,
num_neighbors);
NumPickByEtype(
with_seed_offsets, fanouts, replace, type_per_edge.value(),
probs_or_mask, offset, num_neighbors, num_picked_ptr, seed_index,
etype_id_to_num_picked_offset);
} else {
return NumPick(fanouts[0], replace, probs_or_mask, offset, num_neighbors);
NumPick(
fanouts[0], replace, probs_or_mask, offset, num_neighbors,
num_picked_ptr + seed_index);
}
};
}
......@@ -423,21 +430,25 @@ auto GetPickFn(
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, bool with_seed_offsets,
SamplerArgs<S> args) {
return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args,
with_seed_offsets](
int64_t offset, int64_t num_neighbors, auto picked_data_ptr,
int64_t seed_offset, auto subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset) {
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// types.
if (fanouts.size() > 1) {
return PickByEtype(
offset, num_neighbors, fanouts, replace, options,
type_per_edge.value(), probs_or_mask, args, picked_data_ptr);
with_seed_offsets, offset, num_neighbors, fanouts, replace, options,
type_per_edge.value(), probs_or_mask, args, picked_data_ptr,
seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset);
} else {
int64_t num_sampled = Pick(
offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
args, picked_data_ptr);
args, picked_data_ptr + subgraph_indptr_ptr[seed_offset]);
if (type_per_edge) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
......@@ -484,6 +495,304 @@ auto GetTemporalPickFn(
/**
 * @brief Core CPU implementation of (layer-)neighbor sampling shared by the
 * homogeneous and heterogeneous code paths.
 *
 * The algorithm runs in two parallel passes over the seeds:
 *   Step 1 counts, per seed (and per edge type in the hetero path), how many
 *   neighbors will be picked; Step 2 turns those counts into a prefix sum
 *   (the subgraph indptr); Steps 3-5 allocate the output tensors and perform
 *   the actual picking and gathering of indices / edge types.
 *
 * @param seeds The seed node IDs to sample neighbors for.
 * @param seed_offsets Optional per-node-type offsets of `seeds`. When present
 *        together with fanouts.size() > 1, the hetero fast path is taken:
 *        pick numbers are laid out grouped by edge type rather than by seed.
 *        NOTE(review): assumed to be a prefix-sum style vector indexed by
 *        ntype id (seed_offsets[t]..seed_offsets[t+1] are the seeds of type
 *        t) — consistent with the upper_bound lookup below; confirm at call
 *        sites.
 * @param fanouts Per-edge-type fanouts (size 1 means sample ignoring etypes).
 * @param return_eids Whether to return the picked edge IDs in the subgraph.
 * @param num_pick_fn Callable computing pick counts (see GetNumPickFn).
 * @param pick_fn Callable performing the actual picks (see GetPickFn).
 * @return The sampled subgraph as a FusedSampledSubgraph.
 */
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
    const torch::Tensor& seeds,
    torch::optional<std::vector<int64_t>>& seed_offsets,
    const std::vector<int64_t>& fanouts, bool return_eids,
    NumPickFn num_pick_fn, PickFn pick_fn) const {
  const int64_t num_seeds = seeds.size(0);
  const auto indptr_options = indptr_.options();
  // Calculate GrainSize for parallel_for.
  // Set the default grain size to 64.
  const int64_t grain_size = 64;
  torch::Tensor picked_eids;
  torch::Tensor subgraph_indptr;
  torch::Tensor subgraph_indices;
  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
  torch::optional<torch::Tensor> edge_offsets = torch::nullopt;
  bool with_seed_offsets = seed_offsets.has_value();
  // Hetero fast path: seeds grouped by ntype AND per-etype fanouts given.
  bool hetero_with_seed_offsets = with_seed_offsets && fanouts.size() > 1;
  // Get the number of edge types. If it's homo or if the size of fanouts is 1
  // (hetero graph but sampled as a homo graph), set num_etypes as 1.
  // In temporal sampling, this will not be used for now since the logic hasn't
  // been adopted for temporal sampling.
  const int64_t num_etypes =
      (edge_type_to_id_.has_value() && hetero_with_seed_offsets)
          ? edge_type_to_id_->size()
          : 1;
  std::vector<int64_t> etype_id_to_src_ntype_id(num_etypes);
  std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
  torch::optional<torch::Tensor> subgraph_indptr_substract = torch::nullopt;
  // The pick numbers are stored in a single tensor by the order of etype. Each
  // etype corresponds to a group of seeds whose ntype are the same as the
  // dst_type. `etype_id_to_num_picked_offset` indicates the beginning offset
  // where each etype's corresponding seeds' pick numbers are stored in the pick
  // number tensor.
  std::vector<int64_t> etype_id_to_num_picked_offset(num_etypes + 1);
  if (hetero_with_seed_offsets) {
    for (auto& etype_and_id : edge_type_to_id_.value()) {
      auto etype = etype_and_id.key();
      auto id = etype_and_id.value();
      auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype);
      auto dst_ntype_id = node_type_to_id_->at(dst_type);
      etype_id_to_src_ntype_id[id] = node_type_to_id_->at(src_type);
      etype_id_to_dst_ntype_id[id] = dst_ntype_id;
      // Reserve (#seeds with this etype's dst ntype + 1) slots; the extra
      // leading slot per etype accumulates that etype's total pick number.
      etype_id_to_num_picked_offset[id + 1] =
          seed_offsets->at(dst_ntype_id + 1) - seed_offsets->at(dst_ntype_id) +
          1;
    }
    // Turn per-etype sizes into starting offsets (in place).
    std::partial_sum(
        etype_id_to_num_picked_offset.begin(),
        etype_id_to_num_picked_offset.end(),
        etype_id_to_num_picked_offset.begin());
  } else {
    etype_id_to_dst_ntype_id[0] = 0;
    etype_id_to_num_picked_offset[1] = num_seeds + 1;
  }
  // `num_rows` indicates the length of `num_picked_neighbors_per_node`, which
  // is used for storing pick numbers. In non-temporal hetero sampling, it
  // equals to sum_{etype} #seeds with ntype=dst_type(etype). In homo sampling,
  // it equals to `num_seeds`.
  const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];
  torch::Tensor num_picked_neighbors_per_node =
      torch::empty({num_rows}, indptr_options);
  AT_DISPATCH_INDEX_TYPES(
      indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
        using indptr_t = index_t;
        AT_DISPATCH_INDEX_TYPES(
            seeds.scalar_type(), "SampleNeighborsImplWrappedWithSeeds", ([&] {
              using seeds_t = index_t;
              const auto indptr_data = indptr_.data_ptr<indptr_t>();
              const auto num_picked_neighbors_data_ptr =
                  num_picked_neighbors_per_node.data_ptr<indptr_t>();
              num_picked_neighbors_data_ptr[0] = 0;
              const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();
              // Initialize the empty spots in `num_picked_neighbors_per_node`
              // (the per-etype accumulator slots described above).
              if (hetero_with_seed_offsets) {
                for (auto i = 0; i < num_etypes; ++i) {
                  num_picked_neighbors_data_ptr
                      [etype_id_to_num_picked_offset[i]] = 0;
                }
              }
              // Step 1. Calculate pick number of each node.
              torch::parallel_for(
                  0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
                    for (int64_t i = begin; i < end; ++i) {
                      const auto nid = seeds_data_ptr[i];
                      TORCH_CHECK(
                          nid >= 0 && nid < NumNodes(),
                          "The seed nodes' IDs should fall within the range of "
                          "the graph's node IDs.");
                      const auto offset = indptr_data[nid];
                      const auto num_neighbors = indptr_data[nid + 1] - offset;
                      // Binary-search which ntype group seed i belongs to.
                      const auto seed_type_id =
                          (hetero_with_seed_offsets)
                              ? std::upper_bound(
                                    seed_offsets->begin(), seed_offsets->end(),
                                    i) -
                                    seed_offsets->begin() - 1
                              : 0;
                      // `seed_index` indicates the index of the current
                      // seed within the group of seeds which have the same
                      // node type.
                      const auto seed_index =
                          (hetero_with_seed_offsets)
                              ? i - seed_offsets->at(seed_type_id)
                              : i;
                      // `+ 1` keeps slot 0 as the leading zero of the later
                      // prefix sum (cumsum) in Step 2.
                      num_pick_fn(
                          offset, num_neighbors,
                          num_picked_neighbors_data_ptr + 1, seed_index,
                          etype_id_to_num_picked_offset);
                    }
                  });
              if (hetero_with_seed_offsets) {
                // Build a correction tensor: because each etype section of the
                // pick-number tensor carries an accumulator slot, the cumsum in
                // Step 2 over-counts later sections; `subgraph_indptr_substract`
                // removes that cumulative excess per section.
                torch::Tensor num_picked_offset_tensor =
                    torch::zeros({num_etypes + 1}, indptr_options);
                torch::Tensor substract_offset =
                    torch::zeros({num_etypes}, indptr_options);
                const auto substract_offset_data_ptr =
                    substract_offset.data_ptr<indptr_t>();
                const auto num_picked_offset_data_ptr =
                    num_picked_offset_tensor.data_ptr<indptr_t>();
                for (auto i = 0; i < num_etypes; ++i) {
                  num_picked_offset_data_ptr[i + 1] =
                      etype_id_to_num_picked_offset[i + 1];
                  // Collect the total pick number for each edge type.
                  if (i + 1 < num_etypes)
                    substract_offset_data_ptr[i + 1] =
                        num_picked_neighbors_data_ptr
                            [etype_id_to_num_picked_offset[i]];
                  // Reset the accumulator slot so the Step 2 cumsum treats it
                  // as an empty row.
                  num_picked_neighbors_data_ptr
                      [etype_id_to_num_picked_offset[i]] = 0;
                }
                substract_offset =
                    substract_offset.cumsum(0, indptr_.scalar_type());
                subgraph_indptr_substract = ops::ExpandIndptr(
                    num_picked_offset_tensor, indptr_.scalar_type(),
                    substract_offset);
              }
              // Step 2. Calculate prefix sum to get total length and offsets of
              // each node. It's also the indptr of the generated subgraph.
              subgraph_indptr = num_picked_neighbors_per_node.cumsum(
                  0, indptr_.scalar_type());
              auto subgraph_indptr_data_ptr =
                  subgraph_indptr.data_ptr<indptr_t>();
              // When doing non-temporal hetero sampling, we generate an
              // edge_offsets tensor: the boundary of each etype's edges in the
              // picked-edge tensor.
              if (hetero_with_seed_offsets) {
                edge_offsets = torch::empty({num_etypes + 1}, indptr_options);
                AT_DISPATCH_INTEGRAL_TYPES(
                    edge_offsets.value().scalar_type(), "CalculateEdgeOffsets",
                    ([&] {
                      auto edge_offsets_data_ptr =
                          edge_offsets.value().data_ptr<scalar_t>();
                      edge_offsets_data_ptr[0] = 0;
                      for (auto i = 0; i < num_etypes; ++i) {
                        // Last indptr entry of section i == end of etype i's
                        // edges.
                        edge_offsets_data_ptr[i + 1] = subgraph_indptr_data_ptr
                            [etype_id_to_num_picked_offset[i + 1] - 1];
                      }
                    }));
              }
              // Step 3. Allocate the tensor for picked neighbors.
              const auto total_length =
                  subgraph_indptr.data_ptr<indptr_t>()[num_rows - 1];
              picked_eids = torch::empty({total_length}, indptr_options);
              subgraph_indices =
                  torch::empty({total_length}, indices_.options());
              if (!hetero_with_seed_offsets && type_per_edge_.has_value()) {
                subgraph_type_per_edge = torch::empty(
                    {total_length}, type_per_edge_.value().options());
              }
              auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
              torch::parallel_for(
                  0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
                    for (int64_t i = begin; i < end; ++i) {
                      const auto nid = seeds_data_ptr[i];
                      const auto offset = indptr_data[nid];
                      const auto num_neighbors = indptr_data[nid + 1] - offset;
                      auto picked_number = 0;
                      const auto seed_type_id =
                          (hetero_with_seed_offsets)
                              ? std::upper_bound(
                                    seed_offsets->begin(), seed_offsets->end(),
                                    i) -
                                    seed_offsets->begin() - 1
                              : 0;
                      const auto seed_index =
                          (hetero_with_seed_offsets)
                              ? i - seed_offsets->at(seed_type_id)
                              : i;
                      // Step 4. Pick neighbors for each node.
                      picked_number = pick_fn(
                          offset, num_neighbors, picked_eids_data_ptr,
                          seed_index, subgraph_indptr_data_ptr,
                          etype_id_to_num_picked_offset);
                      // In the hetero path, pick_fn verifies the counts
                      // per-etype internally; here we only check the homo path.
                      if (!hetero_with_seed_offsets) {
                        TORCH_CHECK(
                            num_picked_neighbors_data_ptr[i + 1] ==
                                picked_number,
                            "Actual picked count doesn't match the calculated "
                            "pick number.");
                      }
                      // Step 5. Calculate other attributes and return the
                      // subgraph.
                      if (picked_number > 0) {
                        AT_DISPATCH_INDEX_TYPES(
                            subgraph_indices.scalar_type(),
                            "IndexSelectSubgraphIndices", ([&] {
                              auto subgraph_indices_data_ptr =
                                  subgraph_indices.data_ptr<index_t>();
                              auto indices_data_ptr =
                                  indices_.data_ptr<index_t>();
                              // NOTE(review): this loop variable `i` shadows
                              // the outer seed index `i`; only `seed_index`
                              // and `seed_type_id` from the outer scope are
                              // used below, so behavior is correct, but a
                              // rename would aid readability.
                              for (auto i = 0; i < num_etypes; ++i) {
                                // Only etypes whose dst ntype matches this
                                // seed's ntype have picks for it.
                                if (etype_id_to_dst_ntype_id[i] != seed_type_id)
                                  continue;
                                const auto indptr_offset =
                                    with_seed_offsets
                                        ? etype_id_to_num_picked_offset[i] +
                                              seed_index
                                        : seed_index;
                                const auto picked_begin =
                                    subgraph_indptr_data_ptr[indptr_offset];
                                const auto picked_end =
                                    subgraph_indptr_data_ptr[indptr_offset + 1];
                                // Gather the picked edges' endpoint node IDs.
                                for (auto j = picked_begin; j < picked_end;
                                     ++j) {
                                  subgraph_indices_data_ptr[j] =
                                      indices_data_ptr[picked_eids_data_ptr[j]];
                                  if (hetero_with_seed_offsets &&
                                      node_type_offset_.has_value()) {
                                    // Subtract the node type offset from
                                    // subgraph indices. Assuming
                                    // node_type_offset has the same dtype as
                                    // indices.
                                    auto node_type_offset_data =
                                        node_type_offset_.value()
                                            .data_ptr<index_t>();
                                    subgraph_indices_data_ptr[j] -=
                                        node_type_offset_data
                                            [etype_id_to_src_ntype_id[i]];
                                  }
                                }
                              }
                            }));
                        if (!hetero_with_seed_offsets &&
                            type_per_edge_.has_value()) {
                          // When hetero graph is sampled as a homo graph, we
                          // still generate type_per_edge tensor for this
                          // situation.
                          AT_DISPATCH_INTEGRAL_TYPES(
                              subgraph_type_per_edge.value().scalar_type(),
                              "IndexSelectTypePerEdge", ([&] {
                                auto subgraph_type_per_edge_data_ptr =
                                    subgraph_type_per_edge.value()
                                        .data_ptr<scalar_t>();
                                auto type_per_edge_data_ptr =
                                    type_per_edge_.value().data_ptr<scalar_t>();
                                const auto picked_offset =
                                    subgraph_indptr_data_ptr[seed_index];
                                for (auto j = picked_offset;
                                     j < picked_offset + picked_number; ++j)
                                  subgraph_type_per_edge_data_ptr[j] =
                                      type_per_edge_data_ptr
                                          [picked_eids_data_ptr[j]];
                              }));
                        }
                      }
                    }
                  });
            }));
      }));
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  // picked_eids is not used after this move.
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  // Remove the accumulator-slot inflation from the indptr (hetero path only).
  if (subgraph_indptr_substract.has_value()) {
    subgraph_indptr -= subgraph_indptr_substract.value();
  }
  return c10::make_intrusive<FusedSampledSubgraph>(
      subgraph_indptr, subgraph_indices, seeds, torch::nullopt,
      subgraph_reverse_edge_ids, subgraph_type_per_edge, edge_offsets);
}
template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
......@@ -663,6 +972,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
bool with_seed_offsets = seed_offsets.has_value();
if (layer) {
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
......@@ -670,11 +981,13 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
with_seed_offsets),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
probs_or_mask, with_seed_offsets, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
......@@ -689,20 +1002,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}();
return SampleNeighborsImpl(
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask,
with_seed_offsets),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
probs_or_mask, with_seed_offsets, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
seeds.value(), seed_offsets, fanouts, return_eids,
GetNumPickFn(
fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
with_seed_offsets, args));
}
}
......@@ -734,7 +1050,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return SampleNeighborsImpl(
return TemporalSampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
......@@ -745,7 +1061,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp, args));
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
return TemporalSampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
......@@ -806,12 +1122,13 @@ void FusedCSCSamplingGraph::HoldSharedMemoryObject(
tensor_data_shm_ = std::move(tensor_data_shm);
}
int64_t NumPick(
template <typename PickedNumType>
void NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors) {
int64_t num_neighbors, PickedNumType* picked_num_ptr) {
int64_t num_valid_neighbors = num_neighbors;
if (probs_or_mask.has_value()) {
if (probs_or_mask.has_value() && num_neighbors > 0) {
// Subtract the count of zeros in probs_or_mask.
AT_DISPATCH_ALL_TYPES(
probs_or_mask.value().scalar_type(), "CountZero", ([&] {
......@@ -821,8 +1138,11 @@ int64_t NumPick(
0);
}));
}
if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
return replace ? fanout : std::min(fanout, num_valid_neighbors);
if (num_valid_neighbors == 0 || fanout == -1) {
*picked_num_ptr = num_valid_neighbors;
} else {
*picked_num_ptr = replace ? fanout : std::min(fanout, num_valid_neighbors);
}
}
torch::Tensor TemporalMask(
......@@ -926,14 +1246,16 @@ int64_t TemporalNumPick(
return replace ? fanout : std::min(fanout, num_valid_neighbors);
}
int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
template <typename PickedNumType>
void NumPickByEtype(
bool with_seed_offsets, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors) {
int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index,
const std::vector<int64_t>& etype_id_to_num_picked_offset) {
int64_t etype_begin = offset;
const int64_t end = offset + num_neighbors;
int64_t total_count = 0;
PickedNumType total_count = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "NumPickFnByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
......@@ -947,13 +1269,32 @@ int64_t NumPickByEtype(
etype);
int64_t etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
total_count += NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin);
if (with_seed_offsets) {
// The pick numbers aren't stored continuously, but separately for
// each different etype.
const auto offset =
etype_id_to_num_picked_offset[etype] + seed_index;
NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin, num_picked_ptr + offset);
// Use the skipped position of each edge type in the
// num_picked_tensor to sum up the total pick number for each edge
// type.
num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] +=
num_picked_ptr[offset];
} else {
PickedNumType picked_count = 0;
NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin, &picked_count);
total_count += picked_count;
}
etype_begin = etype_end;
}
}));
return total_count;
if (!with_seed_offsets) {
num_picked_ptr[seed_index] = total_count;
}
}
int64_t TemporalNumPickByEtype(
......@@ -1265,6 +1606,7 @@ int64_t Pick(
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
if (fanout == 0 || num_neighbors == 0) return 0;
if (probs_or_mask.has_value()) {
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask.value(),
......@@ -1326,14 +1668,16 @@ int64_t TemporalPick(
template <SamplerType S, typename PickedType>
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
bool with_seed_offsets, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
PickedType* picked_data_ptr, int64_t seed_index,
PickedType* subgraph_indptr_ptr,
const std::vector<int64_t>& etype_id_to_num_picked_offset) {
int64_t etype_begin = offset;
int64_t etype_end = offset;
int64_t pick_offset = 0;
int64_t picked_total_count = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "PickByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
......@@ -1348,17 +1692,36 @@ int64_t PickByEtype(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
// Do sampling for one etype. The picked nodes aren't stored
// continuously, but separately for each different etype.
if (fanout != 0) {
int64_t picked_count = Pick(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask, args, picked_data_ptr + pick_offset);
pick_offset += picked_count;
auto picked_count = 0;
if (with_seed_offsets) {
const auto indptr_offset =
etype_id_to_num_picked_offset[etype] + seed_index;
picked_count = Pick(
etype_begin, etype_end - etype_begin, fanout, replace,
options, probs_or_mask, args,
picked_data_ptr + subgraph_indptr_ptr[indptr_offset]);
TORCH_CHECK(
subgraph_indptr_ptr[indptr_offset + 1] -
subgraph_indptr_ptr[indptr_offset] ==
picked_count,
"Actual picked count doesn't match the calculated "
"pick number.");
} else {
picked_count = Pick(
etype_begin, etype_end - etype_begin, fanout, replace,
options, probs_or_mask, args,
picked_data_ptr + subgraph_indptr_ptr[seed_index] +
picked_total_count);
}
picked_total_count += picked_count;
}
etype_begin = etype_end;
}
}));
return pick_offset;
return picked_total_count;
}
template <SamplerType S, typename PickedType>
......@@ -1409,7 +1772,7 @@ std::enable_if_t<is_labor(S), int64_t> Pick(
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
if (fanout == 0) return 0;
if (fanout == 0 || num_neighbors == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
return NonUniformPick(
......
......@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number(
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
).to(F.ctx())
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([0, 1])
nodes = {
"N0": torch.LongTensor([0]).to(F.ctx()),
"N1": torch.LongTensor([1]).to(F.ctx()),
}
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment