Unverified Commit f26316ed authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Quick code polish of csc_sampling_graph.cc. (#5957)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 86befc63
......@@ -185,9 +185,10 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
if (type_per_edge_.has_value())
if (type_per_edge_.has_value()) {
subgraph_type_per_edge =
torch::index_select(type_per_edge_.value(), 0, picked_eids);
}
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
......@@ -250,7 +251,7 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
......@@ -263,15 +264,13 @@ inline torch::Tensor UniformPick(
torch::Tensor picked_neighbors;
if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
} else {
if (replace) {
} else if (replace) {
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
picked_neighbors = torch::randperm(num_neighbors, options);
picked_neighbors = picked_neighbors.slice(0, 0, fanout) + offset;
}
}
return picked_neighbors;
}
......@@ -298,7 +297,7 @@ inline torch::Tensor UniformPick(
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is preformed with or
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment