Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6e1cc7da
Unverified
Commit
6e1cc7da
authored
Nov 17, 2022
by
Rhett Ying
Committed by
GitHub
Nov 17, 2022
Browse files
[Dist] fix is_node check according to policy (#4911)
* [Dist] fix is_node check according to policy * add more tests
parent
b377e1b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
32 deletions
+59
-32
python/dgl/distributed/graph_partition_book.py
python/dgl/distributed/graph_partition_book.py
+34
-31
tests/distributed/test_partition.py
tests/distributed/test_partition.py
+25
-1
No files found.
python/dgl/distributed/graph_partition_book.py
View file @
6e1cc7da
...
...
@@ -1231,21 +1231,22 @@ class PartitionPolicy(object):
"""
def
__init__
(
self
,
policy_str
,
partition_book
):
if
POLICY_DELIMITER
not
in
policy_str
:
assert
policy_str
in
(
EDG
E_PART_POLICY
,
NOD
E_PART_POLICY
,
)
,
"policy_str must contain 'edge' or 'node'."
if
NODE_PART_POLICY
==
policy_str
:
policy_str
=
NODE_PART_POLICY
+
POLICY_DELIMITER
+
DEFAULT_NTYPE
else
:
policy_str
=
EDGE_PART_POLICY
+
POLICY_DELIMITER
+
DEFAULT_ETYPE
[
1
]
assert
(
policy_str
.
startswith
(
NODE_PART_POLICY
)
or
policy_str
.
startswith
(
EDGE_PART_POLICY
)),
(
f
"policy_str must start with
{
NOD
E_PART_POLICY
}
or "
f
"
{
EDG
E_PART_POLICY
}
, but got
{
policy_str
}
."
)
if
NODE_PART_POLICY
==
policy_str
:
policy_str
=
NODE_PART_POLICY
+
POLICY_DELIMITER
+
DEFAULT_NTYPE
if
EDGE_PART_POLICY
==
policy_str
:
policy_str
=
EDGE_PART_POLICY
+
POLICY_DELIMITER
+
DEFAULT_ETYPE
[
1
]
self
.
_policy_str
=
policy_str
self
.
_part_id
=
partition_book
.
partid
self
.
_partition_book
=
partition_book
part_policy
,
self
.
_type_name
=
policy_str
.
split
(
POLICY_DELIMITER
,
1
)
if
part_policy
==
EDGE_PART_POLICY
:
self
.
_type_name
=
_etype_str_to_tuple
(
self
.
_type_name
)
self
.
_is_node
=
self
.
policy_str
.
startswith
(
NODE_PART_POLICY
)
@
property
def
policy_str
(
self
):
...
...
@@ -1291,10 +1292,20 @@ class PartitionPolicy(object):
"""
return
self
.
_partition_book
@
property
def
is_node
(
self
):
"""Indicate whether the policy is for node or edge
Returns
-------
bool
node or edge
"""
return
self
.
_is_node
def
get_data_name
(
self
,
name
):
"""Get HeteroDataName"""
is_node
=
NODE_PART_POLICY
in
self
.
policy_str
return
HeteroDataName
(
is_node
,
self
.
type_name
,
name
)
return
HeteroDataName
(
self
.
is_node
,
self
.
type_name
,
name
)
def
to_local
(
self
,
id_tensor
):
"""Mapping global ID to local ID.
...
...
@@ -1309,16 +1320,14 @@ class PartitionPolicy(object):
tensor
local ID tensor
"""
if
EDGE_PART_POLICY
in
self
.
policy_str
:
return
self
.
_partition_book
.
eid2localeid
(
id_tensor
,
self
.
_part_id
,
self
.
type_name
)
elif
NODE_PART_POLICY
in
self
.
policy_str
:
if
self
.
is_node
:
return
self
.
_partition_book
.
nid2localnid
(
id_tensor
,
self
.
_part_id
,
self
.
type_name
)
else
:
raise
RuntimeError
(
"Cannot support policy: %s "
%
self
.
policy_str
)
return
self
.
_partition_book
.
eid2localeid
(
id_tensor
,
self
.
_part_id
,
self
.
type_name
)
def
to_partid
(
self
,
id_tensor
):
"""Mapping global ID to partition ID.
...
...
@@ -1333,12 +1342,10 @@ class PartitionPolicy(object):
tensor
partition ID
"""
if
EDGE_PART_POLICY
in
self
.
policy_str
:
return
self
.
_partition_book
.
eid2partid
(
id_tensor
,
self
.
type_name
)
elif
NODE_PART_POLICY
in
self
.
policy_str
:
if
self
.
is_node
:
return
self
.
_partition_book
.
nid2partid
(
id_tensor
,
self
.
type_name
)
else
:
r
aise
RuntimeError
(
"Cannot support policy: %s "
%
self
.
policy_str
)
r
eturn
self
.
_partition_book
.
eid2partid
(
id_tensor
,
self
.
type_name
)
def
get_part_size
(
self
):
"""Get data size of current partition.
...
...
@@ -1348,16 +1355,14 @@ class PartitionPolicy(object):
int
data size
"""
if
EDGE_PART_POLICY
in
self
.
policy_str
:
return
len
(
self
.
_partition_book
.
partid2eids
(
self
.
_part_id
,
self
.
type_name
)
)
elif
NODE_PART_POLICY
in
self
.
policy_str
:
if
self
.
is_node
:
return
len
(
self
.
_partition_book
.
partid2nids
(
self
.
_part_id
,
self
.
type_name
)
)
else
:
raise
RuntimeError
(
"Cannot support policy: %s "
%
self
.
policy_str
)
return
len
(
self
.
_partition_book
.
partid2eids
(
self
.
_part_id
,
self
.
type_name
)
)
def
get_size
(
self
):
"""Get the full size of the data.
...
...
@@ -1367,12 +1372,10 @@ class PartitionPolicy(object):
int
data size
"""
if
EDGE_PART_POLICY
in
self
.
policy_str
:
return
self
.
_partition_book
.
_num_edges
(
self
.
type_name
)
elif
NODE_PART_POLICY
in
self
.
policy_str
:
if
self
.
is_node
:
return
self
.
_partition_book
.
_num_nodes
(
self
.
type_name
)
else
:
r
aise
RuntimeError
(
"Cannot support policy: %s "
%
self
.
policy_str
)
r
eturn
self
.
_partition_book
.
_num_edges
(
self
.
type_name
)
class
NodePartitionPolicy
(
PartitionPolicy
):
...
...
tests/distributed/test_partition.py
View file @
6e1cc7da
import
os
import
backend
as
F
import
torch
as
th
import
dgl
import
numpy
as
np
import
pytest
...
...
@@ -588,7 +589,7 @@ def test_BasicPartitionBook():
def
test_RangePartitionBook
():
part_id
=
0
part_id
=
1
num_parts
=
2
# homogeneous
...
...
@@ -662,10 +663,33 @@ def test_RangePartitionBook():
expect_except
=
True
assert
expect_except
# NodePartitionPolicy
node_policy
=
NodePartitionPolicy
(
gpb
,
"node1"
)
assert
node_policy
.
type_name
==
"node1"
assert
node_policy
.
policy_str
==
"node~node1"
assert
node_policy
.
part_id
==
part_id
assert
node_policy
.
is_node
assert
node_policy
.
get_data_name
(
'x'
).
is_node
()
local_ids
=
th
.
arange
(
0
,
1000
)
global_ids
=
local_ids
+
1000
assert
th
.
equal
(
node_policy
.
to_local
(
global_ids
),
local_ids
)
assert
th
.
all
(
node_policy
.
to_partid
(
global_ids
)
==
part_id
)
assert
node_policy
.
get_part_size
()
==
1000
assert
node_policy
.
get_size
()
==
2000
# EdgePartitionPolicy
edge_policy
=
EdgePartitionPolicy
(
gpb
,
c_etype
)
assert
edge_policy
.
type_name
==
c_etype
assert
edge_policy
.
policy_str
==
"edge~node1:edge1:node2"
assert
edge_policy
.
part_id
==
part_id
assert
not
edge_policy
.
is_node
assert
not
edge_policy
.
get_data_name
(
'x'
).
is_node
()
local_ids
=
th
.
arange
(
0
,
5000
)
global_ids
=
local_ids
+
5000
assert
th
.
equal
(
edge_policy
.
to_local
(
global_ids
),
local_ids
)
assert
th
.
all
(
edge_policy
.
to_partid
(
global_ids
)
==
part_id
)
assert
edge_policy
.
get_part_size
()
==
5000
assert
edge_policy
.
get_size
()
==
10000
expect_except
=
False
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment