"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1899457b24fdac6e0c7a8280e10035d216efe06c"
Unverified Commit c4aa74ba authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add probs for neighbor sampling (#5774)

parent b03d70d3
......@@ -128,8 +128,8 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* each fanout value corresponds to a specific edge type of the node.
* The value of each fanout should be >= 0 or = -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* equivalent to selecting all neighbors with non-zero probability when the
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is performed with or
......@@ -137,13 +137,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities or boolean mask associated with each neighboring edge of a
* node. It must be a 1D floating-point or boolean tensor with the number of
* elements equal to the number of edges.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids) const;
bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const;
/**
* @brief Copy the graph to shared memory.
......@@ -221,26 +226,40 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors.
*
* If 'probs_or_mask' is provided, it indicates that the sampling is
* non-uniform. In such cases:
* - When the number of neighbors with non-zero probability is less than or
* equal to fanout, all neighbors with non-zero probability will be selected.
* - When the number of neighbors with non-zero probability exceeds fanout, the
* sampling process will select 'fanout' elements based on their respective
* probabilities. Higher probabilities will increase the chances of being chosen
* during the sampling process.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* equivalent to selecting all neighbors with non-zero probability when the
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A tensor containing the picked neighbors.
*/
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options);
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
......@@ -251,9 +270,9 @@ torch::Tensor Pick(
* @param num_neighbors The number of neighbors to pick.
* @param fanouts The edge sampling numbers corresponding to each edge type for
* a single node. The value of each fanout should be >= 0 or = -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors when the fanout is >= the number of
* neighbors (and replacement is set to false).
* - When the value is -1, all neighbors with non-zero probability will be
* chosen for sampling. It is equivalent to selecting all neighbors when the
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum threshold
* for selecting neighbors.
* @param replace Boolean indicating whether the sample is performed with or
......@@ -262,13 +281,18 @@ torch::Tensor Pick(
* @param options Tensor options specifying the desired data type of the result.
* @param type_per_edge Tensor representing the type of each edge in the
* original graph.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A tensor containing the picked neighbors.
*/
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge);
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask);
} // namespace sampling
} // namespace graphbolt
......
......@@ -123,8 +123,17 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids) const {
bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const {
const int64_t num_nodes = nodes.size(0);
// Note probs will be passed as input for 'torch.multinomial' in deeper stack,
// which doesn't support 'torch.half' and 'torch.bool' data types. To avoid
// crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.has_value() &&
(probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16)) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
......@@ -153,10 +162,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value());
type_per_edge_.value(), probs_or_mask);
} else {
picked_neighbors_per_node[i] =
Pick(offset, num_neighbors, fanouts[0], replace, indptr_.options());
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
......@@ -210,7 +220,27 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
}
torch::Tensor Pick(
/**
* @brief Perform uniform sampling of elements and return the sampled indices.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors with non-zero probability when the
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param options Tensor options specifying the desired data type of the result.
*
* @return A tensor containing the picked neighbors.
*/
inline torch::Tensor UniformPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options) {
torch::Tensor picked_neighbors;
......@@ -221,17 +251,86 @@ torch::Tensor Pick(
picked_neighbors =
torch::randint(offset, offset + num_neighbors, {fanout}, options);
} else {
picked_neighbors = torch::randperm(num_neighbors, options) + offset;
picked_neighbors = picked_neighbors.slice(0, 0, fanout);
picked_neighbors = torch::randperm(num_neighbors, options);
picked_neighbors = picked_neighbors.slice(0, 0, fanout) + offset;
}
}
return picked_neighbors;
}
/**
 * @brief Pick neighbors of a single node according to per-edge weights and
 * return the picked edge indices.
 *
 * Only the slice [offset, offset + num_neighbors) of 'probs_or_mask' is
 * consulted. The outcome depends on the fanout and on how many entries of that
 * slice are non-zero:
 * - If fanout is 0, or no entry of the slice is non-zero, an empty tensor is
 * returned.
 * - If fanout is -1, or sampling is without replacement and the number of
 * non-zero entries is less than or equal to fanout, every neighbor with a
 * non-zero entry is returned exactly once.
 * - Otherwise 'fanout' indices are drawn with torch::multinomial, weighted by
 * the slice, so larger weights are more likely to be chosen.
 *
 * @param offset The starting edge ID for the connected neighbors of the sampled
 * node.
 * @param num_neighbors The number of neighbors the node has, i.e. the length
 * of the slice starting at 'offset'.
 * @param fanout The number of edges to sample for the node. It should be
 * >= 0 or -1.
 * - When the value is -1, all neighbors with non-zero probability are chosen.
 * - When the value is a non-negative integer, it is the number of edges to
 * draw; without replacement the result holds at most as many edges as there
 * are neighbors with non-zero probability.
 * @param replace Boolean indicating whether the sample is performed with or
 * without replacement. If True, a value can be selected multiple times.
 * Otherwise, each value can be selected only once.
 * @param options Tensor options used for the empty result and for the
 * "select all" result.
 * @param probs_or_mask Tensor containing the (unnormalized) probabilities
 * associated with each edge of the original graph. It must hold a value (it is
 * dereferenced unconditionally) and must be a 1D tensor indexable by edge ID.
 *
 * @return A tensor containing the picked neighbors.
 */
inline torch::Tensor NonUniformPick(
    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask) {
  // A zero fanout asks for no edges at all. Answer directly instead of
  // requesting zero samples from torch::multinomial below; this mirrors the
  // explicit 'fanout != 0' guard used when sampling per edge type.
  if (fanout == 0) return torch::tensor({}, options);

  auto local_probs =
      probs_or_mask.value().slice(0, offset, offset + num_neighbors);
  auto positive_probs_indices = local_probs.nonzero().squeeze(1);
  auto num_positive_probs = positive_probs_indices.size(0);
  if (num_positive_probs == 0) return torch::tensor({}, options);

  if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
    // Every candidate with non-zero weight fits into the budget: keep them all
    // and skip random sampling entirely.
    auto all_neighbors =
        torch::arange(offset, offset + num_neighbors, options);
    return torch::index_select(all_neighbors, 0, positive_probs_indices);
  }

  // Reaching this point without replacement implies
  // num_positive_probs > fanout, so 'fanout' needs no further clamping.
  // NOTE(review): torch::multinomial produces int64 indices regardless of
  // 'options' -- confirm callers always pass int64 options.
  return torch::multinomial(local_probs, fanout, replace) + offset;
}
// Dispatches neighbor picking for one node: uniform sampling when no
// per-edge weights are supplied, weighted sampling otherwise.
torch::Tensor Pick(
    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask) {
  // Without weights every neighbor is treated alike.
  if (!probs_or_mask.has_value()) {
    return UniformPick(offset, num_neighbors, fanout, replace, options);
  }
  return NonUniformPick(
      offset, num_neighbors, fanout, replace, options, probs_or_mask);
}
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge) {
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask) {
std::vector<torch::Tensor> picked_neighbors(
fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset;
......@@ -245,8 +344,9 @@ torch::Tensor PickByEtype(
}
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] =
Pick(etype_begin, etype_end - etype_begin, fanout, replace, options);
picked_neighbors[etype] = Pick(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask);
}
etype_begin = etype_end;
}
......
......@@ -200,6 +200,7 @@ class CSCSamplingGraph:
fanouts: torch.Tensor,
replace: bool = False,
return_eids: bool = False,
probs_or_mask: Optional[torch.Tensor] = None,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -231,7 +232,12 @@ class CSCSamplingGraph:
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
typically used when edge features are required.
probs_or_mask: torch.Tensor, optional
Optional tensor containing the (unnormalized) probabilities
associated with each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor with the number of elements equal
to the number of edges.
Returns
-------
SampledSubgraph
......@@ -273,8 +279,21 @@ class CSCSamplingGraph:
assert len(self.metadata.edge_type_to_id) == fanouts.size(
0
), "Fanouts should have the same number of elements as etypes."
if probs_or_mask is not None:
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
probs_or_mask.size(0) == self.num_edges
), "Probs should have the same number of elements as the number \
of edges."
assert probs_or_mask.dtype in [
torch.bool,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
], "Probs should have a floating-point or boolean data type."
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, return_eids
nodes, fanouts.tolist(), replace, return_eids, probs_or_mask
)
def copy_to_shared_memory(self, shared_memory_name: str):
......
......@@ -516,6 +516,107 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
assert sampled_num == expected_sampled_num
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize(
    "probs_or_mask",
    [
        torch.tensor([2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]),
        torch.tensor(
            [True, False, True, False, True, True]
            + [True, False, True, True, True, True]
        ),
    ],
)
def test_sample_neighbors_probs(replace, probs_or_mask):
    """Check the number of edges sampled when per-edge weights are given.

    Dense adjacency matrix of the graph under test:
        1   0   1   0   1
        1   0   1   1   0
        0   1   0   1   0
        0   1   0   0   1
        1   0   0   0   1

    For both parametrized inputs, seed nodes 1, 3 and 4 have 1, 1 and 3
    incoming edges with a non-zero weight. With a fanout of 2 this gives
    2 + 2 + 2 = 6 edges with replacement and 1 + 1 + 2 = 4 without.
    """
    # Graph in CSC form: 5 nodes, 12 edges.
    total_edges = 12
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert indptr[-1] == total_edges
    assert indptr[-1] == len(indices)

    csc_graph = gb.from_csc(indptr, indices)

    # Sample around the seed nodes using the supplied weights or mask.
    seeds = torch.LongTensor([1, 3, 4])
    sampled = csc_graph.sample_neighbors(
        seeds,
        fanouts=torch.tensor([2]),
        replace=replace,
        probs_or_mask=probs_or_mask,
    )

    expected_num = 6 if replace else 4
    assert sampled.indices.size(0) == expected_num
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize(
    "probs_or_mask",
    [
        torch.zeros(12, dtype=torch.float32),
        torch.zeros(12, dtype=torch.bool),
    ],
)
def test_sample_neighbors_zero_probs(replace, probs_or_mask):
    """Check that nothing is sampled when every edge weight is zero."""
    # Graph in CSC form: 5 nodes, 12 edges.
    total_edges = 12
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    assert indptr[-1] == total_edges
    assert indptr[-1] == len(indices)

    csc_graph = gb.from_csc(indptr, indices)

    # No edge is eligible, so the sampled subgraph must be empty regardless
    # of the fanout or the replacement setting.
    seeds = torch.LongTensor([1, 3, 4])
    sampled = csc_graph.sample_neighbors(
        seeds,
        fanouts=torch.tensor([5]),
        replace=replace,
        probs_or_mask=probs_or_mask,
    )

    assert sampled.indices.size(0) == 0
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"""Check if two tensors are on the same shared memory.
......@@ -633,4 +734,10 @@ def test_hetero_graph_on_shared_memory(
if __name__ == "__main__":
    # Direct invocation runs one representative case of each sampling test
    # without going through pytest.
    test_sample_neighbors()
    test_sample_neighbors_replace(True, 12)

    edge_weights = torch.tensor(
        [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0, 0.4, 1.2]
    )
    test_sample_neighbors_probs(False, edge_weights)

    all_zero_weights = torch.zeros(12, dtype=torch.float32)
    test_sample_neighbors_zero_probs(True, all_zero_weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment