Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
95cf6924
Unverified
Commit
95cf6924
authored
Jun 04, 2023
by
peizhou001
Committed by
GitHub
Jun 04, 2023
Browse files
[Graphbolt] Return etypes in the neighbor sampling (#5773)
parent
ae97049e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
3 deletions
+33
-3
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+5
-1
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+24
-0
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+4
-2
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
95cf6924
...
...
@@ -169,11 +169,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch
::
Tensor
picked_eids
=
torch
::
cat
(
picked_neighbors_per_node
);
torch
::
Tensor
subgraph_indices
=
torch
::
index_select
(
indices_
,
0
,
picked_eids
);
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
if
(
type_per_edge_
.
has_value
())
subgraph_type_per_edge
=
torch
::
index_select
(
type_per_edge_
.
value
(),
0
,
picked_eids
);
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
subgraph_indptr
,
subgraph_indices
,
nodes
,
torch
::
nullopt
,
subgraph_reverse_edge_ids
,
torch
::
nullopt
);
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
);
}
c10
::
intrusive_ptr
<
CSCSamplingGraph
>
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
95cf6924
...
...
@@ -232,6 +232,30 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
Returns
-------
SampledSubgraph
The sampled subgraph.
Examples
--------
>>> indptr = torch.LongTensor([0, 3, 5, 7])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1])
>>> type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
>>> nodes = torch.LongTensor([1, 2])
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
>>> print(subgraph.indptr)
tensor([0, 2, 4])
>>> print(subgraph.indices)
tensor([2, 3, 0, 1])
>>> print(subgraph.reverse_column_node_ids)
tensor([1, 2])
>>> print(subgraph.reverse_edge_ids)
tensor([3, 4, 5, 6])
>>> print(subgraph.type_per_edge)
tensor([0, 1, 0, 1])
"""
# Ensure nodes is 1-D tensor.
assert
nodes
.
dim
()
==
1
,
"Nodes should be 1-D tensor."
...
...
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
95cf6924
...
...
@@ -411,7 +411,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
fanouts
=
torch
.
tensor
([
2
,
2
,
3
])
fanouts
=
torch
.
tensor
([
2
,
2
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
,
return_eids
=
True
)
# Verify in subgraph.
...
...
@@ -424,8 +424,10 @@ def test_sample_neighbors():
assert
torch
.
equal
(
subgraph
.
reverse_edge_ids
,
torch
.
LongTensor
([
3
,
4
,
7
,
8
,
9
,
10
,
11
])
)
assert
torch
.
equal
(
subgraph
.
type_per_edge
,
torch
.
LongTensor
([
0
,
1
,
0
,
1
,
0
,
0
,
1
])
)
assert
subgraph
.
reverse_row_node_ids
is
None
assert
subgraph
.
type_per_edge
is
None
@
unittest
.
skipIf
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment