Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
95d62394
Unverified
Commit
95d62394
authored
Jan 04, 2024
by
Mingbang Wang
Committed by
GitHub
Jan 04, 2024
Browse files
[GraphBolt] Use `diff()` to calculate the differences for simplicity (#6884)
parent
e9162491
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
20 deletions
+8
-20
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+2
-6
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+3
-3
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+3
-11
No files found.
python/dgl/graphbolt/sampled_subgraph.py
View file @
95d62394
...
...
@@ -227,13 +227,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
original_row_node_ids
,
dim
=
0
,
index
=
indices
)
if
original_column_node_ids
is
not
None
:
indptr
=
original_column_node_ids
.
repeat_interleave
(
indptr
[
1
:]
-
indptr
[:
-
1
]
)
indptr
=
original_column_node_ids
.
repeat_interleave
(
indptr
.
diff
())
else
:
indptr
=
torch
.
arange
(
len
(
indptr
)
-
1
).
repeat_interleave
(
indptr
[
1
:]
-
indptr
[:
-
1
]
)
indptr
=
torch
.
arange
(
len
(
indptr
)
-
1
).
repeat_interleave
(
indptr
.
diff
())
return
(
indices
,
indptr
)
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
95d62394
...
...
@@ -1402,7 +1402,7 @@ def test_from_dglgraph_homogeneous():
dgl_g
,
is_homogeneous
=
True
,
include_original_edge_id
=
True
)
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns
=
gb_g
.
csc_indptr
[
1
:]
-
gb_g
.
csc_indptr
[:
-
1
]
num_columns
=
gb_g
.
csc_indptr
.
diff
()
rows
=
gb_g
.
indices
columns
=
torch
.
arange
(
gb_g
.
total_num_nodes
).
repeat_interleave
(
num_columns
)
...
...
@@ -1456,11 +1456,11 @@ def test_from_dglgraph_heterogeneous():
# `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the
# node id in Hetero-DGLGraph.
num_ntypes
=
gb_g
.
node_type_offset
[
1
:]
-
gb_g
.
node_type_offset
[:
-
1
]
num_ntypes
=
gb_g
.
node_type_offset
.
diff
()
reverse_node_id
=
torch
.
cat
([
torch
.
arange
(
num
)
for
num
in
num_ntypes
])
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns
=
gb_g
.
csc_indptr
[
1
:]
-
gb_g
.
csc_indptr
[:
-
1
]
num_columns
=
gb_g
.
csc_indptr
.
diff
()
rows
=
reverse_node_id
[
gb_g
.
indices
]
columns
=
reverse_node_id
[
torch
.
arange
(
gb_g
.
total_num_nodes
).
repeat_interleave
(
num_columns
)
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
95d62394
...
...
@@ -664,10 +664,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges
=
block
.
edges
(
etype
=
etype
)
dst_ndoes
=
torch
.
arange
(
0
,
len
(
sampled_csc
[
i
][
relation
].
indptr
)
-
1
).
repeat_interleave
(
sampled_csc
[
i
][
relation
].
indptr
[
1
:]
-
sampled_csc
[
i
][
relation
].
indptr
[:
-
1
]
)
).
repeat_interleave
(
sampled_csc
[
i
][
relation
].
indptr
.
diff
())
assert
torch
.
equal
(
edges
[
0
],
sampled_csc
[
i
][
relation
].
indices
)
assert
torch
.
equal
(
edges
[
1
],
dst_ndoes
)
assert
torch
.
equal
(
...
...
@@ -676,10 +673,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges
=
blocks
[
0
].
edges
(
etype
=
gb
.
etype_str_to_tuple
(
reverse_relation
))
dst_ndoes
=
torch
.
arange
(
0
,
len
(
sampled_csc
[
0
][
reverse_relation
].
indptr
)
-
1
).
repeat_interleave
(
sampled_csc
[
0
][
reverse_relation
].
indptr
[
1
:]
-
sampled_csc
[
0
][
reverse_relation
].
indptr
[:
-
1
]
)
).
repeat_interleave
(
sampled_csc
[
0
][
reverse_relation
].
indptr
.
diff
())
assert
torch
.
equal
(
edges
[
0
],
sampled_csc
[
0
][
reverse_relation
].
indices
)
assert
torch
.
equal
(
edges
[
1
],
dst_ndoes
)
assert
torch
.
equal
(
...
...
@@ -704,9 +698,7 @@ def check_dgl_blocks_homo(minibatch, blocks):
for
i
,
block
in
enumerate
(
blocks
):
dst_ndoes
=
torch
.
arange
(
0
,
len
(
sampled_csc
[
i
].
indptr
)
-
1
).
repeat_interleave
(
sampled_csc
[
i
].
indptr
[
1
:]
-
sampled_csc
[
i
].
indptr
[:
-
1
]
)
).
repeat_interleave
(
sampled_csc
[
i
].
indptr
.
diff
())
assert
torch
.
equal
(
block
.
edges
()[
0
],
sampled_csc
[
i
].
indices
),
print
(
block
.
edges
()
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment