Unverified Commit 69d9b726 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Use binary search in neighbor sample (#5891)

parent 381421b7
...@@ -343,13 +343,14 @@ torch::Tensor PickByEtype( ...@@ -343,13 +343,14 @@ torch::Tensor PickByEtype(
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "PickByEtype", ([&] { type_per_edge.scalar_type(), "PickByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>(); const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
while (etype_end < offset + num_neighbors) { const auto end = offset + num_neighbors;
scalar_t etype = type_per_edge_data[etype_end]; while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
int64_t fanout = fanouts[etype]; int64_t fanout = fanouts[etype];
while (etype_end < offset + num_neighbors && auto etype_end_it = std::upper_bound(
type_per_edge_data[etype_end] == etype) { type_per_edge_data + etype_begin, type_per_edge_data + end,
etype_end++; etype);
} etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype. // Do sampling for one etype.
if (fanout != 0) { if (fanout != 0) {
picked_neighbors[etype] = Pick( picked_neighbors[etype] = Pick(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment