Unverified Commit a2234d60 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Graphbolt] Modify labor to fix inconsistencies (#6117)


Co-authored-by: default avatarRamon Zhou <deluxurous@gmail.com>
Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 707f2ae9
......@@ -556,6 +556,10 @@ torch::Tensor Pick<SamplerType::LABOR>(
SamplerArgs<SamplerType::LABOR> args) {
if (fanout == 0) return torch::tensor({}, options);
if (probs_or_mask.has_value()) {
if (fanout < 0) {
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask);
}
torch::Tensor picked_neighbors;
AT_DISPATCH_FLOATING_TYPES(
probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
......@@ -568,6 +572,8 @@ torch::Tensor Pick<SamplerType::LABOR>(
}
}));
return picked_neighbors;
} else if (fanout < 0) {
return UniformPick(offset, num_neighbors, fanout, replace, options);
} else if (replace) {
return LaborPick<false, true>(
offset, num_neighbors, fanout, options,
......@@ -614,12 +620,7 @@ inline torch::Tensor LaborPick(
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args) {
// TODO: fix inconsistency with Neighbor sampler.
// 1. Replace = true, fanout = -1. Expected: sample all neighbors with
// non-zero probility once regardless of replacement.
// 2. Replace = true, fanout > num_neighbors. Expected: sample fanout many
// neighbors.
fanout = fanout < 0 ? num_neighbors : std::min(fanout, num_neighbors);
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
return torch::arange(offset, offset + num_neighbors, options);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment