Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9273387e
Unverified
Commit
9273387e
authored
Feb 04, 2024
by
Rhett Ying
Committed by
GitHub
Feb 04, 2024
Browse files
[GraphBolt] move return_eids check to internal python API (#7071)
parent
1e6fa711
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
6 deletions
+14
-6
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+14
-6
No files found.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
9273387e
...
@@ -625,8 +625,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -625,8 +625,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
if
isinstance
(
nodes
,
dict
):
if
isinstance
(
nodes
,
dict
):
nodes
=
self
.
_convert_to_homogeneous_nodes
(
nodes
)
nodes
=
self
.
_convert_to_homogeneous_nodes
(
nodes
)
return_eids
=
(
self
.
edge_attributes
is
not
None
and
ORIGINAL_EDGE_ID
in
self
.
edge_attributes
)
C_sampled_subgraph
=
self
.
_sample_neighbors
(
C_sampled_subgraph
=
self
.
_sample_neighbors
(
nodes
,
fanouts
,
replace
,
probs_name
nodes
,
fanouts
,
replace
=
replace
,
probs_name
=
probs_name
,
return_eids
=
return_eids
,
)
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
...
@@ -679,6 +687,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -679,6 +687,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts
:
torch
.
Tensor
,
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
probs_name
:
Optional
[
str
]
=
None
,
return_eids
:
bool
=
False
,
)
->
torch
.
ScriptObject
:
)
->
torch
.
ScriptObject
:
"""Sample neighboring edges of the given nodes and return the induced
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
subgraph.
...
@@ -714,6 +723,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -714,6 +723,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
corresponding to each neighboring edge of a node. It must be a 1D
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
equalling the total number of edges.
return_eids: bool, optional
Boolean indicating whether to return the original edge IDs of the
sampled edges.
Returns
Returns
-------
-------
...
@@ -722,16 +734,12 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -722,16 +734,12 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
"""
# Ensure nodes is 1-D tensor.
# Ensure nodes is 1-D tensor.
self
.
_check_sampler_arguments
(
nodes
,
fanouts
,
probs_name
)
self
.
_check_sampler_arguments
(
nodes
,
fanouts
,
probs_name
)
has_original_eids
=
(
self
.
edge_attributes
is
not
None
and
ORIGINAL_EDGE_ID
in
self
.
edge_attributes
)
return
self
.
_c_csc_graph
.
sample_neighbors
(
return
self
.
_c_csc_graph
.
sample_neighbors
(
nodes
,
nodes
,
fanouts
.
tolist
(),
fanouts
.
tolist
(),
replace
,
replace
,
False
,
False
,
has_original
_eids
,
return
_eids
,
probs_name
,
probs_name
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment