Unverified Commit 6a3685be authored by Da Zheng, committed by GitHub

[BUGFIX] fix a bug in graph partitioning. (#1769)

* fix

* use utils.toindex in the right place.

* fix.

* update tensor for mxnet backend.

* fix

* fix
parent 22a6ad6d
@@ -35,14 +35,23 @@ def cpu():
     return mx.cpu()
 
 def tensor(data, dtype=None):
-    # MXNet always returns a float tensor regardless of type inside data.
-    # This is a workaround.
-    if dtype is None:
-        if isinstance(data[0], numbers.Integral):
-            dtype = np.int64
-        else:
-            dtype = np.float32
-    return nd.array(data, dtype=dtype)
+    if isinstance(data, nd.NDArray):
+        if dtype is None or data.dtype == dtype:
+            return data
+        else:
+            return nd.cast(data, dtype)
+    else:
+        if dtype is None:
+            if isinstance(data, numbers.Number):
+                dtype = np.int64 if isinstance(data, numbers.Integral) else np.float32
+            elif isinstance(data, np.ndarray):
+                dtype = data.dtype
+                # mxnet doesn't support bool
+                if dtype == np.bool:
+                    dtype = np.int32
+            else:
+                dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32
+        return nd.array(data, dtype=dtype)
 
 def as_scalar(data):
     return data.asscalar()
...
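The rewritten MXNet `tensor()` no longer forces every input through `nd.array`: an existing `NDArray` is returned as-is (or cast), and other inputs get a dtype inferred from a Python scalar, a numpy array, or the first element, with bool numpy arrays mapped to int32 since MXNet has no bool dtype. Below is a minimal standalone sketch, not DGL code, assuming only `mxnet` and `numpy` are installed, that mirrors the new conversion rules so the bool handling can be checked in isolation.

```python
import numbers
import numpy as np
from mxnet import nd

def to_mx_tensor(data, dtype=None):
    """Mirror of the updated backend tensor(): NDArray inputs pass through
    (or are cast); everything else goes through nd.array with an inferred dtype."""
    if isinstance(data, nd.NDArray):
        return data if dtype is None or data.dtype == dtype else nd.cast(data, dtype)
    if dtype is None:
        if isinstance(data, numbers.Number):
            dtype = np.int64 if isinstance(data, numbers.Integral) else np.float32
        elif isinstance(data, np.ndarray):
            dtype = data.dtype
            # MXNet has no bool dtype; the diff writes np.bool, np.bool_ is the
            # non-deprecated spelling of the same check.
            if dtype == np.bool_:
                dtype = np.int32
        else:
            dtype = np.int64 if isinstance(data[0], numbers.Integral) else np.float32
    return nd.array(data, dtype=dtype)

mask = np.array([True, False, True])
print(to_mx_tensor(mask).dtype)       # int32 instead of a silent float cast
print(to_mx_tensor([1, 2, 3]).dtype)  # int64
```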
@@ -31,7 +31,7 @@ def cpu():
     return th.device('cpu')
 
 def tensor(data, dtype=None):
-    return th.tensor(data, dtype=dtype)
+    return th.as_tensor(data, dtype=dtype)
 
 def as_scalar(data):
     return data.item()
@@ -272,6 +272,8 @@ def clone(input):
     return input.clone()
 
 def unique(input):
+    if input.dtype == th.bool:
+        input = input.type(th.int8)
     return th.unique(input)
 
 def full_1d(length, fill_value, dtype, ctx):
...
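Two things change on the PyTorch side: `tensor()` now uses `torch.as_tensor`, which reuses the memory of existing tensors and numpy arrays instead of always copying, and `unique()` reroutes bool inputs through int8, presumably because `torch.unique` did not accept bool tensors in the PyTorch versions targeted at the time. A small sketch of both behaviours, assuming a standard PyTorch install:

```python
import numpy as np
import torch as th

arr = np.arange(5, dtype=np.int64)
t = th.as_tensor(arr)              # shares memory with arr, no copy
t2 = th.tensor(arr)                # always copies
arr[0] = 100
print(t[0].item(), t2[0].item())   # 100 0 -> as_tensor sees the in-place change

mask = th.tensor([True, False, True])
# Equivalent of the updated backend unique(): cast bool to int8 first.
if mask.dtype == th.bool:
    mask = mask.type(th.int8)
print(th.unique(mask))             # tensor([0, 1], dtype=torch.int8)
```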
@@ -576,8 +576,7 @@ def _get_overlap(mask_arr, ids):
         masks = mask_arr[ids]
         return F.boolean_mask(ids, masks)
     else:
-        mask_arr = utils.toindex(mask_arr)
-        masks = F.gather_row(mask_arr.tousertensor(), ids)
+        masks = F.gather_row(F.tensor(mask_arr), ids)
         return F.boolean_mask(ids, masks)
 
 def _split_local(partition_book, rank, elements, local_eles):
@@ -615,8 +614,7 @@ def _split_even(partition_book, rank, elements):
         # I hope it's OK.
         eles = F.nonzero_1d(elements[0:len(elements)])
     else:
-        elements = utils.toindex(elements)
-        eles = F.nonzero_1d(elements.tousertensor())
+        eles = F.nonzero_1d(F.tensor(elements))
     # here we divide the element list as evenly as possible. If we use range partitioning,
     # the split results also respect the data locality. Range partitioning is the default
...
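`_get_overlap` and `_split_even` previously pushed non-distributed inputs through `utils.toindex(...).tousertensor()`, which coerces to an int64 index array; with `F.tensor(...)` a boolean mask keeps its dtype, so `F.boolean_mask` and `F.nonzero_1d` see what they expect. A hedged sketch of the `_get_overlap` computation written directly in PyTorch, with plain indexing standing in for `F.gather_row` and `F.boolean_mask`:

```python
import numpy as np
import torch as th

mask_arr = np.array([True, False, True, True, False])  # per-node membership mask (illustrative)
ids = th.tensor([0, 1, 3])                              # candidate node IDs

masks = th.as_tensor(mask_arr)[ids]  # ~ F.gather_row(F.tensor(mask_arr), ids): stays bool
overlap = ids[masks]                 # ~ F.boolean_mask(ids, masks)
print(overlap)                       # tensor([0, 3])
```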
@@ -100,12 +100,10 @@ class GraphPartitionBook:
         assert num_parts > 0, 'num_parts must be greater than zero.'
         self._part_id = int(part_id)
         self._num_partitions = int(num_parts)
-        node_map = utils.toindex(node_map)
-        self._nid2partid = node_map.tousertensor()
+        self._nid2partid = F.tensor(node_map)
         assert F.dtype(self._nid2partid) in (F.int32, F.int64), \
                 'the node map must be stored in an integer array'
-        edge_map = utils.toindex(edge_map)
-        self._eid2partid = edge_map.tousertensor()
+        self._eid2partid = F.tensor(edge_map)
         assert F.dtype(self._eid2partid) in (F.int32, F.int64), \
                 'the edge map must be stored in an integer array'
         # Get meta data of the partition book.
@@ -374,10 +372,12 @@ class RangePartitionBook:
         assert num_parts > 0, 'num_parts must be greater than zero.'
         self._partid = part_id
         self._num_partitions = num_parts
-        node_map = utils.toindex(node_map)
-        edge_map = utils.toindex(edge_map)
-        self._node_map = node_map.tonumpy()
-        self._edge_map = edge_map.tonumpy()
+        if not isinstance(node_map, np.ndarray):
+            node_map = F.asnumpy(node_map)
+        if not isinstance(edge_map, np.ndarray):
+            edge_map = F.asnumpy(edge_map)
+        self._node_map = node_map
+        self._edge_map = edge_map
         # Get meta data of the partition book
         self._partition_meta_data = []
         for partid in range(self._num_partitions):
...
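For `GraphPartitionBook` the node/edge maps are now converted with `F.tensor`, while `RangePartitionBook` keeps them as numpy arrays and converts only when the input is not already one. A minimal sketch of that normalization pattern, using plain numpy and PyTorch in place of `F.asnumpy`; the map values are illustrative, not taken from the source:

```python
import numpy as np
import torch as th

def as_numpy_map(m):
    # Keep numpy inputs untouched; convert framework tensors (here torch) to numpy.
    return m if isinstance(m, np.ndarray) else m.numpy()

node_map = th.tensor([100, 200, 300])    # illustrative per-partition boundaries
edge_map = np.array([1000, 2000, 3000])  # already numpy; returned as-is
print(as_numpy_map(node_map), as_numpy_map(edge_map))
```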
@@ -701,8 +701,7 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False):
     if balance_ntypes is not None:
         assert len(balance_ntypes) == g.number_of_nodes(), \
             "The length of balance_ntypes should be equal to #nodes in the graph"
-        balance_ntypes = utils.toindex(balance_ntypes)
-        balance_ntypes = balance_ntypes.tousertensor()
+        balance_ntypes = F.tensor(balance_ntypes)
         uniq_ntypes = F.unique(balance_ntypes)
         for ntype in uniq_ntypes:
             vwgt.append(F.astype(balance_ntypes == ntype, F.int64))
...
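With `balance_ntypes` converted by `F.tensor`, the loop in the diff context builds one 0/1 int64 weight vector per node type for METIS. A short PyTorch sketch, standing in for the backend-agnostic `F` calls; the labels are made up for illustration:

```python
import torch as th

balance_ntypes = th.tensor([0, 0, 1, 2, 1])   # made-up node-type labels
vwgt = []
for ntype in th.unique(balance_ntypes):
    # ~ F.astype(balance_ntypes == ntype, F.int64)
    vwgt.append((balance_ntypes == ntype).to(th.int64))
print(th.stack(vwgt))
# tensor([[1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 1],
#         [0, 0, 0, 1, 0]])
```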