"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8d14edf27ff28a5a37cdb19927579a2d590a7af2"
Unverified Commit a33fafb7 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Change probs to name of attribute (#5968)

parent e6e54304
...@@ -145,10 +145,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -145,10 +145,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* Otherwise, each value can be selected only once. * Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned, * @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required. * typically used when edge features are required.
* @param probs_or_mask Optional tensor containing the (unnormalized) * @param probs_name An optional string specifying the name of an edge
* probabilities or boolean mask associated with each neighboring edge of a * attribute. This attribute tensor should contain (unnormalized)
* node. It must be a 1D floating-point or boolean tensor with the number of * probabilities corresponding to each neighboring edge of a node. It must be
* elements equal to the number of edges. * a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
* *
* @return An intrusive pointer to a SampledSubgraph object containing the * @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information. * sampled graph's information.
...@@ -156,7 +157,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -156,7 +157,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids, bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const; torch::optional<std::string> probs_name) const;
/** /**
* @brief Sample negative edges by randomly choosing negative * @brief Sample negative edges by randomly choosing negative
......
...@@ -132,15 +132,18 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -132,15 +132,18 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids, bool replace, bool return_eids,
torch::optional<torch::Tensor> probs_or_mask) const { torch::optional<std::string> probs_name) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
// Note probs will be passed as input for 'torch.multinomial' in deeper stack, torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
// which doesn't support 'torch.half' and 'torch.bool' data types. To avoid if (probs_name.has_value() && !probs_name.value().empty()) {
// crashes, convert 'probs_or_mask' to 'float32' data type. probs_or_mask = edge_attributes_.value().at(probs_name.value());
if (probs_or_mask.has_value() && // Note probs will be passed as input for 'torch.multinomial' in deeper
(probs_or_mask.value().dtype() == torch::kBool || // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
probs_or_mask.value().dtype() == torch::kFloat16)) { // avoid crashes, convert 'probs_or_mask' to 'float32' data type.
probs_or_mask = probs_or_mask.value().to(torch::kFloat32); if (probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
} }
// If true, perform sampling for each edge type of each node, otherwise just // If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types. // sample once for each node with no regard of edge types.
......
...@@ -219,7 +219,7 @@ class CSCSamplingGraph: ...@@ -219,7 +219,7 @@ class CSCSamplingGraph:
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
return_eids: bool = False, return_eids: bool = False,
probs_or_mask: Optional[torch.Tensor] = None, probs_name: Optional[str] = None,
) -> torch.ScriptObject: ) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -252,11 +252,12 @@ class CSCSamplingGraph: ...@@ -252,11 +252,12 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges, Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is represented as a 1D tensor, should be returned. This is
typically used when edge features are required. typically used when edge features are required.
probs_or_mask: torch.Tensor, optional probs_name: str, optional
Optional tensor containing the (unnormalized) probabilities An optional string specifying the name of an edge attribute. This
associated with each neighboring edge of a node. It must be a 1D attribute tensor should contain (unnormalized) probabilities
floating-point or boolean tensor with the number of elements equal corresponding to each neighboring edge of a node. It must be a 1D
to the number of edges. floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
Returns Returns
------- -------
torch.classes.graphbolt.SampledSubgraph torch.classes.graphbolt.SampledSubgraph
...@@ -302,7 +303,11 @@ class CSCSamplingGraph: ...@@ -302,7 +303,11 @@ class CSCSamplingGraph:
(fanouts >= 0) | (fanouts == -1) (fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \ ), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0." greater than or equal to 0."
if probs_or_mask is not None: if probs_name:
assert (
probs_name in self.edge_attributes
), f"Unknown edge attribute '{probs_name}'."
probs_or_mask = self.edge_attributes[probs_name]
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor." assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert ( assert (
probs_or_mask.size(0) == self.num_edges probs_or_mask.size(0) == self.num_edges
...@@ -316,7 +321,7 @@ class CSCSamplingGraph: ...@@ -316,7 +321,7 @@ class CSCSamplingGraph:
torch.float64, torch.float64,
], "Probs should have a floating-point or boolean data type." ], "Probs should have a floating-point or boolean data type."
return self._c_csc_graph.sample_neighbors( return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, return_eids, probs_or_mask nodes, fanouts.tolist(), replace, return_eids, probs_name
) )
def sample_negative_edges_uniform( def sample_negative_edges_uniform(
......
...@@ -3,7 +3,6 @@ import tempfile ...@@ -3,7 +3,6 @@ import tempfile
import unittest import unittest
import backend as F import backend as F
import dgl import dgl
import dgl.graphbolt as gb import dgl.graphbolt as gb
...@@ -508,29 +507,8 @@ def test_sample_neighbors_replace(replace, expected_sampled_num): ...@@ -508,29 +507,8 @@ def test_sample_neighbors_replace(replace, expected_sampled_num):
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
@pytest.mark.parametrize("replace", [True, False]) @pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize("probs_name", ["weight", "mask"])
"probs_or_mask", def test_sample_neighbors_probs(replace, probs_name):
[
torch.tensor([2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]),
torch.tensor(
[
True,
False,
True,
False,
True,
True,
True,
False,
True,
True,
True,
True,
]
),
],
)
def test_sample_neighbors_probs(replace, probs_or_mask):
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -546,8 +524,15 @@ def test_sample_neighbors_probs(replace, probs_or_mask): ...@@ -546,8 +524,15 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
assert indptr[-1] == num_edges assert indptr[-1] == num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
edge_attributes = {
"weight": torch.FloatTensor(
[2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5, 0.4, 1.2]
),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]),
}
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -555,7 +540,7 @@ def test_sample_neighbors_probs(replace, probs_or_mask): ...@@ -555,7 +540,7 @@ def test_sample_neighbors_probs(replace, probs_or_mask):
nodes, nodes,
fanouts=torch.tensor([2]), fanouts=torch.tensor([2]),
replace=replace, replace=replace,
probs_or_mask=probs_or_mask, probs_name=probs_name,
) )
# Verify in subgraph. # Verify in subgraph.
...@@ -587,8 +572,10 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask): ...@@ -587,8 +572,10 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
assert indptr[-1] == num_edges assert indptr[-1] == num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
edge_attributes = {"probs_or_mask": probs_or_mask}
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
graph = gb.from_csc(indptr, indices) graph = gb.from_csc(indptr, indices, edge_attributes=edge_attributes)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
...@@ -596,7 +583,7 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask): ...@@ -596,7 +583,7 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
nodes, nodes,
fanouts=torch.tensor([5]), fanouts=torch.tensor([5]),
replace=replace, replace=replace,
probs_or_mask=probs_or_mask, probs_name="probs_or_mask",
) )
# Verify in subgraph. # Verify in subgraph.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment