Unverified Commit a33fafb7 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Change probs to name of attribute (#5968)

parent e6e54304
......@@ -145,10 +145,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities or boolean mask associated with each neighboring edge of a
* node. It must be a 1D floating-point or boolean tensor with the number of
* elements equal to the number of edges.
* @param probs_name An optional string specifying the name of an edge
* attribute. This attribute tensor should contain the (unnormalized)
* probabilities or boolean mask corresponding to each neighboring edge of a
* node. It must be a 1D floating-point or boolean tensor, with the number of
* elements equal to the total number of edges. If the name is absent or
* empty, no edge attribute is looked up.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
......@@ -156,7 +157,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const;
torch::optional<std::string> probs_name) const;
/**
* @brief Sample negative edges by randomly choosing negative
......
......@@ -132,15 +132,18 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const {
torch::optional<std::string> probs_name) const {
const int64_t num_nodes = nodes.size(0);
// Note probs will be passed as input for 'torch.multinomial' in deeper stack,
// which doesn't support 'torch.half' and 'torch.bool' data types. To avoid
// crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.has_value() &&
(probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16)) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
if (probs_name.has_value() && !probs_name.value().empty()) {
probs_or_mask = edge_attributes_.value().at(probs_name.value());
// Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
}
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
......
......@@ -219,7 +219,7 @@ class CSCSamplingGraph:
fanouts: torch.Tensor,
replace: bool = False,
return_eids: bool = False,
probs_or_mask: Optional[torch.Tensor] = None,
probs_name: Optional[str] = None,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -252,11 +252,12 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required.
probs_or_mask: torch.Tensor, optional
Optional tensor containing the (unnormalized) probabilities
associated with each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor with the number of elements equal
to the number of edges.
probs_name: str, optional
An optional string specifying the name of an edge attribute. This
attribute tensor should contain the (unnormalized) probabilities or
boolean mask corresponding to each neighboring edge of a node. It
must be a 1D floating-point or boolean tensor, with the number of
elements equal to the total number of edges. The name must be a key
of the graph's edge attributes.
Returns
-------
torch.classes.graphbolt.SampledSubgraph
......@@ -302,7 +303,11 @@ class CSCSamplingGraph:
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
if probs_or_mask is not None:
if probs_name:
assert (
probs_name in self.edge_attributes
), f"Unknown edge attribute '{probs_name}'."
probs_or_mask = self.edge_attributes[probs_name]
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
probs_or_mask.size(0) == self.num_edges
......@@ -316,7 +321,7 @@ class CSCSamplingGraph:
torch.float64,
], "Probs should have a floating-point or boolean data type."
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, return_eids, probs_or_mask
nodes, fanouts.tolist(), replace, return_eids, probs_name
)
def sample_negative_edges_uniform(
......
......@@ -3,7 +3,6 @@ import tempfile
import unittest
import backend as F
import dgl
import dgl.graphbolt as gb
......@@ -508,29 +507,8 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize(
"probs_or_mask",
[
torch.tensor([2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]),
torch.tensor(
[
True,
False,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
]
),
],
)
def test_sample_neighbors_probs(replace, probs_or_mask):
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs(replace, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -546,8 +524,15 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
edge_attributes = {
"weight": torch.FloatTensor(
[2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]
),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -555,7 +540,7 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
nodes,
fanouts=torch.tensor([2]),
replace=replace,
probs_or_mask=probs_or_mask,
probs_name=probs_name,
)
# Verify in subgraph.
......@@ -587,8 +572,10 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices)
graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
......@@ -596,7 +583,7 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
nodes,
fanouts=torch.tensor([5]),
replace=replace,
probs_or_mask=probs_or_mask,
probs_name="probs_or_mask",
)
# Verify in subgraph.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment