Unverified Commit 6e1cc7da authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] fix is_node check according to policy (#4911)

* [Dist] fix is_node check according to policy

* add more tests
parent b377e1b9
...@@ -1231,14 +1231,14 @@ class PartitionPolicy(object): ...@@ -1231,14 +1231,14 @@ class PartitionPolicy(object):
""" """
def __init__(self, policy_str, partition_book): def __init__(self, policy_str, partition_book):
if POLICY_DELIMITER not in policy_str: assert (policy_str.startswith(NODE_PART_POLICY) or
assert policy_str in ( policy_str.startswith(EDGE_PART_POLICY)), (
EDGE_PART_POLICY, f"policy_str must start with {NODE_PART_POLICY} or "
NODE_PART_POLICY, f"{EDGE_PART_POLICY}, but got {policy_str}."
), "policy_str must contain 'edge' or 'node'." )
if NODE_PART_POLICY == policy_str: if NODE_PART_POLICY == policy_str:
policy_str = NODE_PART_POLICY + POLICY_DELIMITER + DEFAULT_NTYPE policy_str = NODE_PART_POLICY + POLICY_DELIMITER + DEFAULT_NTYPE
else: if EDGE_PART_POLICY == policy_str:
policy_str = EDGE_PART_POLICY + POLICY_DELIMITER + DEFAULT_ETYPE[1] policy_str = EDGE_PART_POLICY + POLICY_DELIMITER + DEFAULT_ETYPE[1]
self._policy_str = policy_str self._policy_str = policy_str
self._part_id = partition_book.partid self._part_id = partition_book.partid
...@@ -1246,6 +1246,7 @@ class PartitionPolicy(object): ...@@ -1246,6 +1246,7 @@ class PartitionPolicy(object):
part_policy, self._type_name = policy_str.split(POLICY_DELIMITER, 1) part_policy, self._type_name = policy_str.split(POLICY_DELIMITER, 1)
if part_policy == EDGE_PART_POLICY: if part_policy == EDGE_PART_POLICY:
self._type_name = _etype_str_to_tuple(self._type_name) self._type_name = _etype_str_to_tuple(self._type_name)
self._is_node = self.policy_str.startswith(NODE_PART_POLICY)
@property @property
def policy_str(self): def policy_str(self):
...@@ -1291,10 +1292,20 @@ class PartitionPolicy(object): ...@@ -1291,10 +1292,20 @@ class PartitionPolicy(object):
""" """
return self._partition_book return self._partition_book
@property
def is_node(self):
"""Indicate whether the policy is for node or edge
Returns
-------
bool
node or edge
"""
return self._is_node
def get_data_name(self, name): def get_data_name(self, name):
"""Get HeteroDataName""" """Get HeteroDataName"""
is_node = NODE_PART_POLICY in self.policy_str return HeteroDataName(self.is_node, self.type_name, name)
return HeteroDataName(is_node, self.type_name, name)
def to_local(self, id_tensor): def to_local(self, id_tensor):
"""Mapping global ID to local ID. """Mapping global ID to local ID.
...@@ -1309,16 +1320,14 @@ class PartitionPolicy(object): ...@@ -1309,16 +1320,14 @@ class PartitionPolicy(object):
tensor tensor
local ID tensor local ID tensor
""" """
if EDGE_PART_POLICY in self.policy_str: if self.is_node:
return self._partition_book.eid2localeid(
id_tensor, self._part_id, self.type_name
)
elif NODE_PART_POLICY in self.policy_str:
return self._partition_book.nid2localnid( return self._partition_book.nid2localnid(
id_tensor, self._part_id, self.type_name id_tensor, self._part_id, self.type_name
) )
else: else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str) return self._partition_book.eid2localeid(
id_tensor, self._part_id, self.type_name
)
def to_partid(self, id_tensor): def to_partid(self, id_tensor):
"""Mapping global ID to partition ID. """Mapping global ID to partition ID.
...@@ -1333,12 +1342,10 @@ class PartitionPolicy(object): ...@@ -1333,12 +1342,10 @@ class PartitionPolicy(object):
tensor tensor
partition ID partition ID
""" """
if EDGE_PART_POLICY in self.policy_str: if self.is_node:
return self._partition_book.eid2partid(id_tensor, self.type_name)
elif NODE_PART_POLICY in self.policy_str:
return self._partition_book.nid2partid(id_tensor, self.type_name) return self._partition_book.nid2partid(id_tensor, self.type_name)
else: else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str) return self._partition_book.eid2partid(id_tensor, self.type_name)
def get_part_size(self): def get_part_size(self):
"""Get data size of current partition. """Get data size of current partition.
...@@ -1348,16 +1355,14 @@ class PartitionPolicy(object): ...@@ -1348,16 +1355,14 @@ class PartitionPolicy(object):
int int
data size data size
""" """
if EDGE_PART_POLICY in self.policy_str: if self.is_node:
return len(
self._partition_book.partid2eids(self._part_id, self.type_name)
)
elif NODE_PART_POLICY in self.policy_str:
return len( return len(
self._partition_book.partid2nids(self._part_id, self.type_name) self._partition_book.partid2nids(self._part_id, self.type_name)
) )
else: else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str) return len(
self._partition_book.partid2eids(self._part_id, self.type_name)
)
def get_size(self): def get_size(self):
"""Get the full size of the data. """Get the full size of the data.
...@@ -1367,12 +1372,10 @@ class PartitionPolicy(object): ...@@ -1367,12 +1372,10 @@ class PartitionPolicy(object):
int int
data size data size
""" """
if EDGE_PART_POLICY in self.policy_str: if self.is_node:
return self._partition_book._num_edges(self.type_name)
elif NODE_PART_POLICY in self.policy_str:
return self._partition_book._num_nodes(self.type_name) return self._partition_book._num_nodes(self.type_name)
else: else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str) return self._partition_book._num_edges(self.type_name)
class NodePartitionPolicy(PartitionPolicy): class NodePartitionPolicy(PartitionPolicy):
......
import os import os
import backend as F import backend as F
import torch as th
import dgl import dgl
import numpy as np import numpy as np
import pytest import pytest
...@@ -588,7 +589,7 @@ def test_BasicPartitionBook(): ...@@ -588,7 +589,7 @@ def test_BasicPartitionBook():
def test_RangePartitionBook(): def test_RangePartitionBook():
part_id = 0 part_id = 1
num_parts = 2 num_parts = 2
# homogeneous # homogeneous
...@@ -662,10 +663,33 @@ def test_RangePartitionBook(): ...@@ -662,10 +663,33 @@ def test_RangePartitionBook():
expect_except = True expect_except = True
assert expect_except assert expect_except
# NodePartitionPolicy
node_policy = NodePartitionPolicy(gpb, "node1") node_policy = NodePartitionPolicy(gpb, "node1")
assert node_policy.type_name == "node1" assert node_policy.type_name == "node1"
assert node_policy.policy_str == "node~node1"
assert node_policy.part_id == part_id
assert node_policy.is_node
assert node_policy.get_data_name('x').is_node()
local_ids = th.arange(0, 1000)
global_ids = local_ids + 1000
assert th.equal(node_policy.to_local(global_ids), local_ids)
assert th.all(node_policy.to_partid(global_ids) == part_id)
assert node_policy.get_part_size() == 1000
assert node_policy.get_size() == 2000
# EdgePartitionPolicy
edge_policy = EdgePartitionPolicy(gpb, c_etype) edge_policy = EdgePartitionPolicy(gpb, c_etype)
assert edge_policy.type_name == c_etype assert edge_policy.type_name == c_etype
assert edge_policy.policy_str == "edge~node1:edge1:node2"
assert edge_policy.part_id == part_id
assert not edge_policy.is_node
assert not edge_policy.get_data_name('x').is_node()
local_ids = th.arange(0, 5000)
global_ids = local_ids + 5000
assert th.equal(edge_policy.to_local(global_ids), local_ids)
assert th.all(edge_policy.to_partid(global_ids) == part_id)
assert edge_policy.get_part_size() == 5000
assert edge_policy.get_size() == 10000
expect_except = False expect_except = False
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment