Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
af870387
Unverified
Commit
af870387
authored
Jan 30, 2024
by
Rhett Ying
Committed by
GitHub
Jan 30, 2024
Browse files
[DistGB] add verify logic for GraphBolt partitions (#7031)
parent
cda8b381
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
3 deletions
+66
-3
python/dgl/distributed/partition.py
python/dgl/distributed/partition.py
+66
-3
No files found.
python/dgl/distributed/partition.py
View file @
af870387
...
...
@@ -12,7 +12,7 @@ import torch
from
..
import
backend
as
F
,
graphbolt
as
gb
from
..base
import
dgl_warning
,
DGLError
,
EID
,
ETYPE
,
NID
,
NTYPE
from
..convert
import
to_homogeneous
from
..convert
import
heterograph
,
to_homogeneous
from
..data.utils
import
load_graphs
,
load_tensors
,
save_graphs
,
save_tensors
from
..partition
import
(
get_peak_mem
,
...
...
@@ -190,8 +190,71 @@ def _verify_dgl_partition(graph, part_id, gpb, ntypes, etypes):
def
_verify_graphbolt_partition
(
graph
,
part_id
,
gpb
,
ntypes
,
etypes
):
"""Verify the partition of a GraphBolt graph."""
# [Rui][TODO]
_
,
_
,
_
,
_
,
_
=
graph
,
part_id
,
gpb
,
ntypes
,
etypes
required_ndata_fields
=
[
NID
]
required_edata_fields
=
[
EID
]
assert
all
(
field
in
graph
.
node_attributes
for
field
in
required_ndata_fields
),
"the partition graph should contain node mapping to global node ID."
assert
all
(
field
in
graph
.
edge_attributes
for
field
in
required_edata_fields
),
"the partition graph should contain edge mapping to global edge ID."
num_nodes
=
graph
.
total_num_nodes
num_edges
=
graph
.
total_num_edges
local_src_ids
=
graph
.
indices
local_dst_ids
=
torch
.
repeat_interleave
(
torch
.
arange
(
num_nodes
),
torch
.
diff
(
graph
.
csc_indptr
)
)
global_src_ids
=
graph
.
node_attributes
[
NID
][
local_src_ids
]
global_dst_ids
=
graph
.
node_attributes
[
NID
][
local_dst_ids
]
etype_ids
,
type_wise_eids
=
gpb
.
map_to_per_etype
(
graph
.
edge_attributes
[
EID
])
if
graph
.
type_per_edge
is
not
None
:
assert
torch
.
equal
(
etype_ids
,
graph
.
type_per_edge
)
etype_ids
,
etype_ids_indices
=
torch
.
sort
(
etype_ids
)
global_src_ids
=
global_src_ids
[
etype_ids_indices
]
global_dst_ids
=
global_dst_ids
[
etype_ids_indices
]
type_wise_eids
=
type_wise_eids
[
etype_ids_indices
]
src_ntype_ids
,
src_type_wise_nids
=
gpb
.
map_to_per_ntype
(
global_src_ids
)
dst_ntype_ids
,
dst_type_wise_nids
=
gpb
.
map_to_per_ntype
(
global_dst_ids
)
data_dict
=
dict
()
edge_ids
=
dict
()
for
c_etype
,
etype_id
in
etypes
.
items
():
idx
=
etype_ids
==
etype_id
src_ntype
,
etype
,
dst_ntype
=
c_etype
if
idx
.
sum
()
==
0
:
continue
actual_src_ntype_ids
=
src_ntype_ids
[
idx
]
actual_dst_ntype_ids
=
dst_ntype_ids
[
idx
]
expected_src_ntype_ids
=
ntypes
[
src_ntype
]
expected_dst_ntype_ids
=
ntypes
[
dst_ntype
]
assert
all
(
actual_src_ntype_ids
==
expected_src_ntype_ids
),
(
f
"Unexpected types of source nodes for
{
c_etype
}
. Expected: "
f
"
{
expected_src_ntype_ids
}
, but got:
{
actual_src_ntype_ids
}
."
)
assert
all
(
actual_dst_ntype_ids
==
expected_dst_ntype_ids
),
(
f
"Unexpected types of destination nodes for
{
c_etype
}
. Expected: "
f
"
{
expected_dst_ntype_ids
}
, but got:
{
actual_dst_ntype_ids
}
."
)
data_dict
[
c_etype
]
=
(
src_type_wise_nids
[
idx
],
dst_type_wise_nids
[
idx
])
edge_ids
[
c_etype
]
=
type_wise_eids
[
idx
]
# Make sure node/edge IDs are not out of range.
hg
=
heterograph
(
data_dict
,
{
ntype
:
gpb
.
_num_nodes
(
ntype
)
for
ntype
in
ntypes
}
)
for
etype
in
edge_ids
:
hg
.
edges
[
etype
].
data
[
EID
]
=
edge_ids
[
etype
]
assert
all
(
hg
.
num_edges
(
etype
)
==
len
(
eids
)
for
etype
,
eids
in
edge_ids
.
items
()
),
"The number of edges per etype in the partition graph is not correct."
assert
num_edges
==
hg
.
num_edges
(),
(
f
"The total number of edges in the partition graph is not correct. "
f
"Expected:
{
num_edges
}
, but got:
{
hg
.
num_edges
()
}
."
)
print
(
f
"Partition
{
part_id
}
looks good!"
)
def
load_partition
(
part_config
,
part_id
,
load_feats
=
True
,
use_graphbolt
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment