"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "828a5e5bc6ffaa5716b02283874ec830e1b786fc"
Unverified Commit 2ef90be0 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Dispatch type edges in neighbor sampling (#5890)

parent d22049e8
...@@ -340,21 +340,25 @@ torch::Tensor PickByEtype( ...@@ -340,21 +340,25 @@ torch::Tensor PickByEtype(
fanouts.size(), torch::tensor({}, options)); fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset; int64_t etype_begin = offset;
int64_t etype_end = offset; int64_t etype_end = offset;
while (etype_end < offset + num_neighbors) { AT_DISPATCH_INTEGRAL_TYPES(
int64_t etype = type_per_edge[etype_end].item<int64_t>(); type_per_edge.scalar_type(), "PickByEtype", ([&] {
int64_t fanout = fanouts[etype]; const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
while (etype_end < offset + num_neighbors && while (etype_end < offset + num_neighbors) {
type_per_edge[etype_end].item<int64_t>() == etype) { scalar_t etype = type_per_edge_data[etype_end];
etype_end++; int64_t fanout = fanouts[etype];
} while (etype_end < offset + num_neighbors &&
// Do sampling for one etype. type_per_edge_data[etype_end] == etype) {
if (fanout != 0) { etype_end++;
picked_neighbors[etype] = Pick( }
etype_begin, etype_end - etype_begin, fanout, replace, options, // Do sampling for one etype.
probs_or_mask); if (fanout != 0) {
} picked_neighbors[etype] = Pick(
etype_begin = etype_end; etype_begin, etype_end - etype_begin, fanout, replace, options,
} probs_or_mask);
}
etype_begin = etype_end;
}
}));
return torch::cat(picked_neighbors, 0); return torch::cat(picked_neighbors, 0);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment