Unverified Commit 6e1cc7da authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] fix is_node check according to policy (#4911)

* [Dist] fix is_node check according to policy

* add more tests
parent b377e1b9
......@@ -1231,21 +1231,22 @@ class PartitionPolicy(object):
"""
def __init__(self, policy_str, partition_book):
if POLICY_DELIMITER not in policy_str:
assert policy_str in (
EDGE_PART_POLICY,
NODE_PART_POLICY,
), "policy_str must contain 'edge' or 'node'."
if NODE_PART_POLICY == policy_str:
policy_str = NODE_PART_POLICY + POLICY_DELIMITER + DEFAULT_NTYPE
else:
policy_str = EDGE_PART_POLICY + POLICY_DELIMITER + DEFAULT_ETYPE[1]
assert (policy_str.startswith(NODE_PART_POLICY) or
policy_str.startswith(EDGE_PART_POLICY)), (
f"policy_str must start with {NODE_PART_POLICY} or "
f"{EDGE_PART_POLICY}, but got {policy_str}."
)
if NODE_PART_POLICY == policy_str:
policy_str = NODE_PART_POLICY + POLICY_DELIMITER + DEFAULT_NTYPE
if EDGE_PART_POLICY == policy_str:
policy_str = EDGE_PART_POLICY + POLICY_DELIMITER + DEFAULT_ETYPE[1]
self._policy_str = policy_str
self._part_id = partition_book.partid
self._partition_book = partition_book
part_policy, self._type_name = policy_str.split(POLICY_DELIMITER, 1)
if part_policy == EDGE_PART_POLICY:
self._type_name = _etype_str_to_tuple(self._type_name)
self._is_node = self.policy_str.startswith(NODE_PART_POLICY)
@property
def policy_str(self):
......@@ -1291,10 +1292,20 @@ class PartitionPolicy(object):
"""
return self._partition_book
@property
def is_node(self):
"""Indicate whether the policy is for node or edge
Returns
-------
bool
node or edge
"""
return self._is_node
def get_data_name(self, name):
"""Get HeteroDataName"""
is_node = NODE_PART_POLICY in self.policy_str
return HeteroDataName(is_node, self.type_name, name)
return HeteroDataName(self.is_node, self.type_name, name)
def to_local(self, id_tensor):
"""Mapping global ID to local ID.
......@@ -1309,16 +1320,14 @@ class PartitionPolicy(object):
tensor
local ID tensor
"""
if EDGE_PART_POLICY in self.policy_str:
return self._partition_book.eid2localeid(
id_tensor, self._part_id, self.type_name
)
elif NODE_PART_POLICY in self.policy_str:
if self.is_node:
return self._partition_book.nid2localnid(
id_tensor, self._part_id, self.type_name
)
else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str)
return self._partition_book.eid2localeid(
id_tensor, self._part_id, self.type_name
)
def to_partid(self, id_tensor):
"""Mapping global ID to partition ID.
......@@ -1333,12 +1342,10 @@ class PartitionPolicy(object):
tensor
partition ID
"""
if EDGE_PART_POLICY in self.policy_str:
return self._partition_book.eid2partid(id_tensor, self.type_name)
elif NODE_PART_POLICY in self.policy_str:
if self.is_node:
return self._partition_book.nid2partid(id_tensor, self.type_name)
else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str)
return self._partition_book.eid2partid(id_tensor, self.type_name)
def get_part_size(self):
"""Get data size of current partition.
......@@ -1348,16 +1355,14 @@ class PartitionPolicy(object):
int
data size
"""
if EDGE_PART_POLICY in self.policy_str:
return len(
self._partition_book.partid2eids(self._part_id, self.type_name)
)
elif NODE_PART_POLICY in self.policy_str:
if self.is_node:
return len(
self._partition_book.partid2nids(self._part_id, self.type_name)
)
else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str)
return len(
self._partition_book.partid2eids(self._part_id, self.type_name)
)
def get_size(self):
"""Get the full size of the data.
......@@ -1367,12 +1372,10 @@ class PartitionPolicy(object):
int
data size
"""
if EDGE_PART_POLICY in self.policy_str:
return self._partition_book._num_edges(self.type_name)
elif NODE_PART_POLICY in self.policy_str:
if self.is_node:
return self._partition_book._num_nodes(self.type_name)
else:
raise RuntimeError("Cannot support policy: %s " % self.policy_str)
return self._partition_book._num_edges(self.type_name)
class NodePartitionPolicy(PartitionPolicy):
......
import os
import backend as F
import torch as th
import dgl
import numpy as np
import pytest
......@@ -588,7 +589,7 @@ def test_BasicPartitionBook():
def test_RangePartitionBook():
part_id = 0
part_id = 1
num_parts = 2
# homogeneous
......@@ -662,10 +663,33 @@ def test_RangePartitionBook():
expect_except = True
assert expect_except
# NodePartitionPolicy
node_policy = NodePartitionPolicy(gpb, "node1")
assert node_policy.type_name == "node1"
assert node_policy.policy_str == "node~node1"
assert node_policy.part_id == part_id
assert node_policy.is_node
assert node_policy.get_data_name('x').is_node()
local_ids = th.arange(0, 1000)
global_ids = local_ids + 1000
assert th.equal(node_policy.to_local(global_ids), local_ids)
assert th.all(node_policy.to_partid(global_ids) == part_id)
assert node_policy.get_part_size() == 1000
assert node_policy.get_size() == 2000
# EdgePartitionPolicy
edge_policy = EdgePartitionPolicy(gpb, c_etype)
assert edge_policy.type_name == c_etype
assert edge_policy.policy_str == "edge~node1:edge1:node2"
assert edge_policy.part_id == part_id
assert not edge_policy.is_node
assert not edge_policy.get_data_name('x').is_node()
local_ids = th.arange(0, 5000)
global_ids = local_ids + 5000
assert th.equal(edge_policy.to_local(global_ids), local_ids)
assert th.all(edge_policy.to_partid(global_ids) == part_id)
assert edge_policy.get_part_size() == 5000
assert edge_policy.get_size() == 10000
expect_except = False
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment