Unverified Commit 9273387e authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] move return_eids check to internal python API (#7071)

parent 1e6fa711
...@@ -625,8 +625,16 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -625,8 +625,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
if isinstance(nodes, dict): if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes) nodes = self._convert_to_homogeneous_nodes(nodes)
return_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name nodes,
fanouts,
replace=replace,
probs_name=probs_name,
return_eids=return_eids,
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
...@@ -679,6 +687,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -679,6 +687,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
return_eids: bool = False,
) -> torch.ScriptObject: ) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -714,6 +723,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -714,6 +723,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
corresponding to each neighboring edge of a node. It must be a 1D corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements floating-point or boolean tensor, with the number of elements
equalling the total number of edges. equalling the total number of edges.
return_eids: bool, optional
Boolean indicating whether to return the original edge IDs of the
sampled edges.
Returns Returns
------- -------
...@@ -722,16 +734,12 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -722,16 +734,12 @@ class FusedCSCSamplingGraph(SamplingGraph):
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name) self._check_sampler_arguments(nodes, fanouts, probs_name)
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
return self._c_csc_graph.sample_neighbors( return self._c_csc_graph.sample_neighbors(
nodes, nodes,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
False, False,
has_original_eids, return_eids,
probs_name, probs_name,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment