Unverified commit b379dbd6 authored by Jingcheng Yu, committed by GitHub

Optimize dist_graph/_split_even_to_part memory usage (#3132)


Co-authored-by: yujingcheng02 <yujingcheng02@meituan.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 88f20eec
@@ -1125,6 +1125,21 @@ def replace_inf_with_zero(x):
    """
    pass

def count_nonzero(input):
    """Return the number of non-zero values in the input tensor.

    Parameters
    ----------
    input : Tensor
        The tensor to be counted

    Returns
    -------
    int
        The number of non-zero values
    """
    pass
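The spec leaves the actual computation to each backend. As a backend-agnostic reference (a minimal sketch using NumPy, not DGL code), an implementation must satisfy:

    import numpy as np

    def count_nonzero_reference(input):
        # Reference semantics for the spec above: the number of entries of
        # `input` that are not equal to zero, returned as a plain Python int.
        return int(np.count_nonzero(np.asarray(input)))

    assert count_nonzero_reference([0, 1, 0, 2, 3]) == 3
    assert count_nonzero_reference(np.zeros((2, 2))) == 0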
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
@@ -1643,7 +1658,7 @@ def scatter_add(x, idx, m):
        The indices array.
    m : int
        The length of output.

    Returns
    -------
    Tensor
@@ -342,6 +342,11 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x):
    return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)

def count_nonzero(input):
    # TODO: fallback to numpy is unfortunate
    tmp = input.asnumpy()
    return np.count_nonzero(tmp)

def unique(input):
    # TODO: fallback to numpy is unfortunate
    tmp = input.asnumpy()
@@ -520,10 +525,10 @@ class CopyReduce(mx.autograd.Function):
            in_ones_nd = zerocopy_to_dgl_ndarray(in_ones)
            degs_nd = zerocopy_to_dgl_ndarray(degs)
            K.copy_reduce(
                'sum', self.graph, self.target, in_ones_nd, degs_nd,
                self.in_map[0], self.out_map[0])
            # Reshape the degrees for broadcasting and clip at 1 to avoid
            # dividing by zero for nodes that receive no messages.
            degs = degs.reshape((out_data.shape[0],) + (1,) * (out_data.ndim - 1)).clip(1, float('inf'))
            out_data = out_data / degs
        else:
            degs = None
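For orientation: the division above turns a 'sum' copy-reduce into a mean, and clipping the degrees at 1 keeps zero-degree destinations from dividing by zero. A standalone NumPy sketch of the same trick on toy data (hypothetical arrays, not DGL's kernel):

    import numpy as np

    # Three messages routed to two destination nodes.
    msgs = np.array([[1., 2.], [3., 4.], [5., 6.]])
    dst = np.array([0, 0, 1])

    out = np.zeros((2, 2))
    np.add.at(out, dst, msgs)                  # the 'sum' copy-reduce
    degs = np.bincount(dst, minlength=2).astype(float)
    degs = degs.reshape((out.shape[0],) + (1,) * (out.ndim - 1)).clip(1, np.inf)
    print(out / degs)                          # [[2. 3.], [5. 6.]] -- per-node means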
@@ -3,6 +3,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion
import scipy # Weird bug in new pytorch when import scipy after import torch
import numpy as np
import torch as th
import builtins
import numbers
@@ -290,6 +291,10 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x):
    return th.masked_fill(x, th.isinf(x), 0)

def count_nonzero(input):
    # TODO: fallback to numpy for backward compatibility with older torch
    # versions; note np.count_nonzero only accepts CPU tensors.
    return np.count_nonzero(input)

def unique(input):
    if input.dtype == th.bool:
        input = input.type(th.int8)
@@ -409,6 +409,10 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x):
    return tf.where(tf.abs(x) == np.inf, 0, x)

def count_nonzero(input):
    return int(tf.math.count_nonzero(input))

def unique(input):
    return tf.unique(input).y
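With all three backends covered, the new function should be reachable through DGL's framework-dispatch module. A usage sketch, assuming the internal `dgl.backend` import path and a configured backend:

    import dgl.backend as F   # DGL's framework-dispatch module (internal import path)

    mask = F.tensor([0, 1, 0, 2, 3])
    # Dispatches to np.count_nonzero (MXNet/PyTorch fallbacks) or
    # tf.math.count_nonzero, depending on the active backend.
    print(F.count_nonzero(mask))   # 3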
@@ -1075,26 +1075,53 @@ def _even_offset(n, k):
def _split_even_to_part(partition_book, elements):
    ''' Split the input element list evenly.
    '''
    # Here we divide the element list as evenly as possible. If we use range partitioning,
    # the split results also respect the data locality. Range partitioning is the default
    # strategy.
    # TODO(zhengda) we need another way to divide the list for other partitioning strategies.
    if isinstance(elements, DistTensor):
        # Here we need to fetch all elements from the kvstore server.
        # I hope it's OK.
        eles = F.nonzero_1d(elements[0:len(elements)])
        # Compute the offset of each split and ensure that the difference of each
        # partition size is at most 1.
        offsets = _even_offset(len(eles), partition_book.num_partitions())
        assert offsets[-1] == len(eles)
        # Get the elements that belong to the partition.
        partid = partition_book.partid
        part_eles = eles[offsets[partid] : offsets[partid + 1]]
    else:
        elements = F.tensor(elements)
        nonzero_count = F.count_nonzero(elements)
        # Compute the offset of each split and ensure that the difference of each
        # partition size is at most 1.
        offsets = _even_offset(nonzero_count, partition_book.num_partitions())
        assert offsets[-1] == nonzero_count
        # Get the elements that belong to the partition.
        partid = partition_book.partid
        left, right = offsets[partid], offsets[partid + 1]
        x = y = 0
        num_elements = len(elements)
        block_size = num_elements // partition_book.num_partitions()
        part_eles = None
        # Compute the nonzero tensor block by block instead of on the whole tensor
        # to save memory.
        for idx in range(0, num_elements, block_size):
            nonzero_block = F.nonzero_1d(elements[idx:min(idx + block_size, num_elements)])
            # [x, y) is the global nonzero index range covered by this block.
            x = y
            y += len(nonzero_block)
            if y > left and x < right:
                # This block overlaps the [left, right) range owned by this
                # partition; keep only the overlapping slice, shifted back to
                # global element indices.
                start = max(x, left) - x
                end = min(y, right) - x
                tmp = nonzero_block[start:end] + idx
                if part_eles is None:
                    part_eles = tmp
                else:
                    part_eles = F.cat((part_eles, tmp), 0)
            elif x >= right:
                break
    return part_eles
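`_even_offset` itself sits just above this hunk and is not shown; presumably it returns k + 1 cumulative boundaries that split n items into k parts whose sizes differ by at most one. A self-contained NumPy sketch of that assumed contract and of the block-wise selection above, on hypothetical toy inputs:

    import numpy as np

    def _even_offset(n, k):
        # Assumed contract: k + 1 cumulative boundaries splitting n items into
        # k parts whose sizes differ by at most 1, e.g. n=7, k=3 -> [0, 3, 5, 7].
        part, rem = divmod(n, k)
        sizes = np.full(k, part, dtype=np.int64)
        sizes[:rem] += 1
        return np.concatenate([[0], np.cumsum(sizes)])

    # Toy membership mask over 12 elements, 3 partitions; this worker owns partition 1.
    elements = np.array([0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1])
    num_parts, partid = 3, 1
    offsets = _even_offset(np.count_nonzero(elements), num_parts)
    left, right = offsets[partid], offsets[partid + 1]

    # Scan block by block, as in _split_even_to_part, instead of materializing
    # the nonzero indices of the whole tensor at once.
    x = y = 0
    block_size = len(elements) // num_parts
    part_eles = []
    for idx in range(0, len(elements), block_size):
        nonzero_block = np.nonzero(elements[idx:idx + block_size])[0]
        x, y = y, y + len(nonzero_block)   # global nonzero range [x, y) of this block
        if y > left and x < right:
            # Keep the slice of this block falling in [left, right), shifted
            # back to global element indices.
            part_eles.append(nonzero_block[max(x, left) - x : min(y, right) - x] + idx)
        elif x >= right:
            break

    print(np.concatenate(part_eles))             # -> [6 7]
    print(np.nonzero(elements)[0][left:right])   # same answer, memory-hungry way

Scanning one block at a time keeps peak memory at roughly one block's worth of nonzero indices rather than the whole tensor's, which is the point of this commit.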