Unverified Commit 1e6fa711 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add fast path for tamporal sampling. (#7078)

parent 15695ed0
...@@ -810,12 +810,71 @@ torch::Tensor TemporalMask( ...@@ -810,12 +810,71 @@ torch::Tensor TemporalMask(
return mask; return mask;
} }
/**
* @brief Fast path for temporal sampling without probability. It is used when
* the number of neighbors is large. It randomly samples neighbors and checks
* the timestamp of the neighbors. It is successful if the number of sampled
* neighbors in kTriedThreshold trials is equal to the fanout.
*/
std::pair<bool, std::vector<int64_t>> FastTemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
constexpr int64_t kTriedThreshold = 1000;
auto timestamp = utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);
std::vector<int64_t> sampled_edges;
sampled_edges.reserve(fanout);
std::set<int64_t> sampled_edge_set;
int64_t sample_count = 0;
int64_t tried = 0;
while (sample_count < fanout && tried < kTriedThreshold) {
int64_t edge_id =
RandomEngine::ThreadLocal()->RandInt(offset, offset + num_neighbors);
++tried;
if (!replace && sampled_edge_set.count(edge_id) > 0) {
continue;
}
if (node_timestamp.has_value()) {
int64_t neighbor_id =
utils::GetValueByIndex<int64_t>(csc_indices, edge_id);
if (utils::GetValueByIndex<int64_t>(
node_timestamp.value(), neighbor_id) >= timestamp)
continue;
}
if (edge_timestamp.has_value() &&
utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >=
timestamp) {
continue;
}
if (!replace) {
sampled_edge_set.insert(edge_id);
}
sampled_edges.push_back(edge_id);
sample_count++;
}
if (sample_count < fanout) {
return {false, {}};
}
return {true, sampled_edges};
}
int64_t TemporalNumPick( int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& probs_or_mask, bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset, const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) { int64_t offset, int64_t num_neighbors) {
constexpr int64_t kFastPathThreshold = 1000;
if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) {
// TODO: Currently we use the fast path both in TemporalNumPick and
// TemporalPick. We may only sample once in TemporalNumPick and use the
// sampled edges in TemporalPick to avoid sampling twice.
auto [success, sampled_edges] = FastTemporalPick(
seed_timestamp, csc_indics, fanout, replace, node_timestamp,
edge_timestamp, seed_offset, offset, num_neighbors);
if (success) return sampled_edges.size();
}
auto mask = TemporalMask( auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indics, utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indics,
probs_or_mask, node_timestamp, edge_timestamp, probs_or_mask, node_timestamp, edge_timestamp,
...@@ -1183,6 +1242,19 @@ int64_t TemporalPick( ...@@ -1183,6 +1242,19 @@ int64_t TemporalPick(
const torch::optional<torch::Tensor>& node_timestamp, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args, const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
constexpr int64_t kFastPathThreshold = 1000;
if (S == SamplerType::NEIGHBOR && num_neighbors > kFastPathThreshold &&
!probs_or_mask.has_value()) {
auto [success, sampled_edges] = FastTemporalPick(
seed_timestamp, csc_indices, fanout, replace, node_timestamp,
edge_timestamp, seed_offset, offset, num_neighbors);
if (success) {
for (size_t i = 0; i < sampled_edges.size(); ++i) {
picked_data_ptr[i] = static_cast<PickedType>(sampled_edges[i]);
}
return sampled_edges.size();
}
}
auto mask = TemporalMask( auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices, utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
probs_or_mask, node_timestamp, edge_timestamp, probs_or_mask, node_timestamp, edge_timestamp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment