Unverified Commit b379dbd6 authored by Jingcheng Yu's avatar Jingcheng Yu Committed by GitHub
Browse files

Optimize dist_graph/_split_even_to_part memory usage (#3132)


Co-authored-by: yujingcheng02 <yujingcheng02@meituan.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 88f20eec
...@@ -1125,6 +1125,21 @@ def replace_inf_with_zero(x): ...@@ -1125,6 +1125,21 @@ def replace_inf_with_zero(x):
""" """
pass pass
def count_nonzero(input):
    """Count the number of non-zero entries in the tensor *input*.

    This is a backend-interface stub; each backend (mxnet/pytorch/
    tensorflow) provides the concrete implementation.

    Parameters
    ----------
    input : Tensor
        The tensor whose non-zero entries are counted.

    Returns
    -------
    Integer
        The number of non-zero entries.
    """
    pass
############################################################################### ###############################################################################
# Tensor functions used *only* on index tensor # Tensor functions used *only* on index tensor
# ---------------- # ----------------
......
...@@ -342,6 +342,11 @@ def clamp(data, min_val, max_val): ...@@ -342,6 +342,11 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x) return nd.where(nd.abs(x) == np.inf, nd.zeros_like(x), x)
def count_nonzero(input):
    """Return the number of non-zero entries in ``input``.

    Copies the array to the host and counts with numpy, mirroring the
    numpy fallback already used by ``unique`` in this backend.
    """
    # TODO: fallback to numpy is unfortunate
    host_array = input.asnumpy()
    return np.count_nonzero(host_array)
def unique(input): def unique(input):
# TODO: fallback to numpy is unfortunate # TODO: fallback to numpy is unfortunate
tmp = input.asnumpy() tmp = input.asnumpy()
......
...@@ -3,6 +3,7 @@ from __future__ import absolute_import ...@@ -3,6 +3,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion from distutils.version import LooseVersion
import scipy # Weird bug in new pytorch when import scipy after import torch import scipy # Weird bug in new pytorch when import scipy after import torch
import numpy as np
import torch as th import torch as th
import builtins import builtins
import numbers import numbers
...@@ -290,6 +291,10 @@ def clamp(data, min_val, max_val): ...@@ -290,6 +291,10 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
return th.masked_fill(x, th.isinf(x), 0) return th.masked_fill(x, th.isinf(x), 0)
def count_nonzero(input):
    """Return the number of non-zero elements in ``input``.

    Counts via numpy instead of a native torch call; the TODO below
    records that this is a backward-compatibility fallback.
    """
    # TODO: fallback to numpy for backward compatibility
    nonzero_total = np.count_nonzero(input)
    return nonzero_total
def unique(input): def unique(input):
if input.dtype == th.bool: if input.dtype == th.bool:
input = input.type(th.int8) input = input.type(th.int8)
......
...@@ -409,6 +409,10 @@ def clamp(data, min_val, max_val): ...@@ -409,6 +409,10 @@ def clamp(data, min_val, max_val):
def replace_inf_with_zero(x): def replace_inf_with_zero(x):
return tf.where(tf.abs(x) == np.inf, 0, x) return tf.where(tf.abs(x) == np.inf, 0, x)
def count_nonzero(input):
    """Return the number of non-zero elements in ``input`` as a Python int."""
    nonzero_total = tf.math.count_nonzero(input)
    return int(nonzero_total)
def unique(input): def unique(input):
return tf.unique(input).y return tf.unique(input).y
......
...@@ -1075,18 +1075,14 @@ def _even_offset(n, k): ...@@ -1075,18 +1075,14 @@ def _even_offset(n, k):
def _split_even_to_part(partition_book, elements):
    '''Split the input element mask evenly and return the local share.

    The non-zero positions of ``elements`` are divided as evenly as
    possible across all partitions; the slice owned by the local
    partition (per ``partition_book.partid``) is returned.

    Parameters
    ----------
    partition_book : GraphPartitionBook
        Supplies ``num_partitions()`` and the local ``partid``.
    elements : DistTensor or tensor-like
        Mask over all elements; its non-zero positions are split.

    Returns
    -------
    Tensor or None
        Non-zero indices assigned to the local partition.
    '''
    # Here we divide the element list as evenly as possible. If we use
    # range partitioning, the split results also respect the data
    # locality. Range partitioning is the default strategy.
    # TODO(zhengda) we need another way to divide the list for other
    # partitioning strategies.
    if isinstance(elements, DistTensor):
        # Here we need to fetch all elements from the kvstore server.
        # I hope it's OK.
        eles = F.nonzero_1d(elements[0:len(elements)])
        # Compute the offset of each split and ensure that the sizes of
        # any two partitions differ by at most 1.
        offsets = _even_offset(len(eles), partition_book.num_partitions())
        assert offsets[-1] == len(eles)
        # Get the elements that belong to the partition.
        partid = partition_book.partid
        part_eles = eles[offsets[partid]:offsets[partid + 1]]
    else:
        elements = F.tensor(elements)
        nonzero_count = F.count_nonzero(elements)
        # Compute the offset of each split and ensure that the sizes of
        # any two partitions differ by at most 1.
        offsets = _even_offset(nonzero_count, partition_book.num_partitions())
        assert offsets[-1] == nonzero_count
        # [left, right) is the range of non-zero ranks owned locally.
        partid = partition_book.partid
        left, right = offsets[partid], offsets[partid + 1]
        x = y = 0
        num_elements = len(elements)
        # max(1, ...) guards against num_elements < num_partitions,
        # which would make the range() step zero and raise ValueError.
        block_size = max(1, num_elements // partition_book.num_partitions())
        # NOTE(review): if the local partition owns no non-zero entries
        # (left == right), part_eles stays None here, whereas the
        # DistTensor branch returns an empty tensor — confirm callers
        # tolerate None before unifying.
        part_eles = None
        # Scan block by block so only one block's non-zero indices are
        # materialized at a time, instead of running nonzero_1d on the
        # whole tensor (the memory optimization this function exists for).
        for idx in range(0, num_elements, block_size):
            nonzero_block = F.nonzero_1d(
                elements[idx:min(idx + block_size, num_elements)])
            # [x, y) is the global rank range covered by this block.
            x = y
            y += len(nonzero_block)
            if y > left and x < right:
                # The block overlaps the local range: take the overlap,
                # shifting block-local indices back to global with + idx.
                start = max(x, left) - x
                end = min(y, right) - x
                tmp = nonzero_block[start:end] + idx
                if part_eles is None:
                    part_eles = tmp
                else:
                    part_eles = F.cat((part_eles, tmp), 0)
            elif x >= right:
                # Past the local range; later blocks cannot contribute.
                break
    return part_eles
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment