Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
308f8ca3
Unverified
Commit
308f8ca3
authored
Nov 21, 2023
by
Rhett Ying
Committed by
GitHub
Nov 21, 2023
Browse files
[GraphBolt] enable more dtypes for sample_neighbors (#6523)
parent
ba2ca4be
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
13 deletions
+25
-13
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+4
-4
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+21
-9
No files found.
graphbolt/src/fused_csc_sampling_graph.cc
View file @
308f8ca3
...
...
@@ -335,13 +335,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
scalar_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
nodes_data_ptr
=
nodes
.
data_ptr
<
int64_t
>
();
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes
_data_ptr
[
i
]
;
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
()
;
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of the "
...
...
@@ -356,7 +355,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// Step 2. Calculate prefix sum to get total length and offsets of each
// node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
torch
::
cumsum
(
num_picked_neighbors_per_node
,
0
);
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
...
...
@@ -374,7 +374,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
parallel_for
(
0
,
num_nodes
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
nodes
_data_ptr
[
i
]
;
const
auto
nid
=
nodes
[
i
].
item
<
int64_t
>
()
;
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
picked_number
=
num_picked_neighbors_data_ptr
[
i
+
1
];
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
308f8ca3
...
...
@@ -604,7 +604,10 @@ def test_in_subgraph_heterogeneous():
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_sample_neighbors_homo
():
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"indptr_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"indices_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_sample_neighbors_homo
(
labor
,
indptr_dtype
,
indices_dtype
):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
...
...
@@ -615,17 +618,21 @@ def test_sample_neighbors_homo():
# Initialize data.
total_num_nodes
=
5
total_num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
indptr
=
torch
.
tensor
([
0
,
3
,
5
,
7
,
9
,
12
],
dtype
=
indptr_dtype
)
indices
=
torch
.
tensor
(
[
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
],
dtype
=
indices_dtype
)
assert
indptr
[
-
1
]
==
total_num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
assert
len
(
indptr
)
==
total_num_nodes
+
1
# Construct FusedCSCSamplingGraph.
graph
=
gb
.
from_fused_csc
(
indptr
,
indices
)
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
1
,
3
,
4
])
subgraph
=
graph
.
sample_neighbors
(
nodes
,
fanouts
=
torch
.
LongTensor
([
2
]))
nodes
=
torch
.
tensor
([
1
,
3
,
4
],
dtype
=
indices_dtype
)
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
=
torch
.
LongTensor
([
2
]))
# Verify in subgraph.
sampled_num
=
subgraph
.
node_pairs
[
0
].
size
(
0
)
...
...
@@ -640,7 +647,9 @@ def test_sample_neighbors_homo():
reason
=
"Graph is CPU only at present."
,
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_sample_neighbors_hetero
(
labor
):
@
pytest
.
mark
.
parametrize
(
"indptr_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
@
pytest
.
mark
.
parametrize
(
"indices_dtype"
,
[
torch
.
int32
,
torch
.
int64
])
def
test_sample_neighbors_hetero
(
labor
,
indptr_dtype
,
indices_dtype
):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
...
...
@@ -656,8 +665,8 @@ def test_sample_neighbors_hetero(labor):
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
total_num_nodes
=
5
total_num_edges
=
9
indptr
=
torch
.
LongT
ensor
([
0
,
2
,
4
,
6
,
7
,
9
])
indices
=
torch
.
LongT
ensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
1
])
indptr
=
torch
.
t
ensor
([
0
,
2
,
4
,
6
,
7
,
9
]
,
dtype
=
indptr_dtype
)
indices
=
torch
.
t
ensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
1
]
,
dtype
=
indices_dtype
)
type_per_edge
=
torch
.
LongTensor
([
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
])
node_type_offset
=
torch
.
LongTensor
([
0
,
2
,
5
])
assert
indptr
[
-
1
]
==
total_num_edges
...
...
@@ -673,7 +682,10 @@ def test_sample_neighbors_hetero(labor):
)
# Sample on both node types.
nodes
=
{
"n1"
:
torch
.
LongTensor
([
0
]),
"n2"
:
torch
.
LongTensor
([
0
])}
nodes
=
{
"n1"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
"n2"
:
torch
.
tensor
([
0
],
dtype
=
indices_dtype
),
}
fanouts
=
torch
.
tensor
([
-
1
,
-
1
])
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
subgraph
=
sampler
(
nodes
,
fanouts
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment