"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c3a043eb23027bd3fa417b474d9f753bf9c75e72"
Unverified commit cf8a3fb3, authored by Zihao Ye, committed by GitHub

[hotfix] Address the performance issue of topk. (#2628)



* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent e46983f6
@@ -227,32 +227,40 @@ def randint(shape, dtype, ctx, low, high):
 def pad_packed_tensor(input, lengths, value, l_min=None):
     old_shape = input.shape
-    if isinstance(lengths, th.Tensor):
-        max_len = as_scalar(lengths.max())
+    device = input.device
+    if not is_tensor(lengths):
+        lengths = th.tensor(lengths, dtype=th.int64, device=device)
     else:
-        max_len = builtins.max(lengths)
+        lengths = lengths.to(device)
+    max_len = as_scalar(lengths.max())
     if l_min is not None:
         max_len = builtins.max(max_len, l_min)
     batch_size = len(lengths)
-    device = input.device
     x = input.new(batch_size * max_len, *old_shape[1:])
     x.fill_(value)
-    index = []
-    for i, l in enumerate(lengths):
-        index.extend(range(i * max_len, i * max_len + l))
-    index = th.tensor(index).to(device)
-    return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
+    index = th.ones(len(input), dtype=th.int64, device=device)
+    cum_lengths = th.cumsum(lengths, 0)
+    index[cum_lengths[:-1]] += (max_len - lengths[:-1])
+    index = th.cumsum(index, 0) - 1
+    x[index] = input
+    return x.view(batch_size, max_len, *old_shape[1:])


 def pack_padded_tensor(input, lengths):
-    batch_size, max_len = input.shape[:2]
+    max_len = input.shape[1]
     device = input.device
-    index = []
-    for i, l in enumerate(lengths):
-        index.extend(range(i * max_len, i * max_len + l))
-    index = th.tensor(index).to(device)
-    return gather_row(input.view(batch_size * max_len, -1), index)
+    if not is_tensor(lengths):
+        lengths = th.tensor(lengths, dtype=th.int64, device=device)
+    else:
+        lengths = lengths.to(device)
+    input = input.view(-1, *input.shape[2:])
+    out_len = lengths.sum().item()
+    index = th.ones(out_len, dtype=th.int64, device=device)
+    cum_lengths = th.cumsum(lengths, 0)
+    index[cum_lengths[:-1]] += (max_len - lengths[:-1])
+    index = th.cumsum(index, 0) - 1
+    return input[index]


 def boolean_mask(input, mask):
     if 'bool' not in str(mask.dtype):
...
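Reviewer note: the per-element Python loop that built `index` is what the hotfix removes here; the replacement computes the packed-to-padded index map with two `cumsum` calls. A minimal standalone sketch of the trick (the variable names and toy data below are illustrative, not part of the patch):

import torch as th

lengths = th.tensor([2, 3, 1])                       # rows per graph in the packed tensor
packed = th.arange(6, dtype=th.float32).view(6, 1)   # 2 + 3 + 1 = 6 packed rows
max_len = int(lengths.max())                         # padded length per graph

# The step is 1 within a segment; at each segment boundary it additionally
# skips the padding slots left by the segment that just ended.
index = th.ones(len(packed), dtype=th.int64)
cum_lengths = th.cumsum(lengths, 0)
index[cum_lengths[:-1]] += max_len - lengths[:-1]
index = th.cumsum(index, 0) - 1                      # destination row of each packed row

padded = th.zeros(len(lengths) * max_len, 1)
padded[index] = packed                               # scatter into the padded layout
print(padded.view(len(lengths), max_len).tolist())
# [[0.0, 1.0, 0.0], [2.0, 3.0, 4.0], [5.0, 0.0, 0.0]]

The running sum lands each packed row in its padded slot without any Python-level iteration; `pack_padded_tensor` builds the same map and uses it in the gather direction.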
@@ -485,6 +485,38 @@ READOUT_ON_ATTRS = {
     'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
 }

+
+def _topk_torch(keys, k, descending, x):
+    """Internal function to take graph-wise top-k node/edge features according
+    to the rank given by keys; this function is PyTorch only.
+
+    Parameters
+    ----------
+    keys : Tensor
+        The key for ranking.
+    k : int
+        The :math:`k` in "top-:math:`k`".
+    descending : bool
+        Indicates whether to return the feature corresponding to the largest or
+        smallest elements.
+    x : Tensor
+        The padded feature with shape (batch, max_len, *).
+
+    Returns
+    -------
+    sorted_feat : Tensor
+        A tensor with shape :math:`(batch, k, *)`.
+    sorted_idx : Tensor
+        A tensor with shape :math:`(batch, k)`.
+    """
+    import torch as th
+    batch_size, max_len = x.shape[0], x.shape[1]
+    topk_indices = keys.topk(k, -1, largest=descending)[1]  # (batch_size, k)
+    x = x.view((batch_size * max_len), -1)
+    shift = th.arange(0, batch_size, device=x.device).view(batch_size, 1) * max_len
+    topk_indices_ = topk_indices + shift
+    x = x[topk_indices_].view(batch_size, k, -1)
+    return th.masked_fill(x, th.isinf(x), 0), topk_indices
+
+
 def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
     """Internal function to take graph-wise top-k node/edge features of
     field :attr:`feat` in :attr:`graph` ranked by keys at given
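Reviewer note: `_topk_torch` collapses the old sort-then-repad-then-gather sequence into one `topk` over the padded keys, one batched advanced-indexing gather, and an `isinf` mask to zero out padding. A self-contained sketch of the same steps (shapes and toy data are illustrative):

import torch as th

batch_size, max_len, d, k = 2, 4, 3, 2
x = th.randn(batch_size, max_len, d)
x[1, 3:] = float('-inf')              # second graph has only 3 real rows; padding is -inf
keys = x[..., 0]                      # rank by feature channel 0, i.e. sortby=0

topk_indices = keys.topk(k, -1, largest=True)[1]          # (batch, k), per-graph row ids
flat = x.view(batch_size * max_len, -1)                   # flatten the batch dimension
shift = th.arange(batch_size).view(batch_size, 1) * max_len
out = flat[topk_indices + shift].view(batch_size, k, -1)  # one gather for all graphs
out = th.masked_fill(out, th.isinf(out), 0)               # zero any selected padding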
@@ -534,41 +566,41 @@ def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
     if F.ndim(data[feat]) > 2:
         raise DGLError('Only support {} feature `{}` with dimension less than or'
                        ' equal to 2'.format(typestr, feat))
     feat = data[feat]
     hidden_size = F.shape(feat)[-1]
     batch_num_objs = getattr(graph, batch_num_objs_attr)(ntype_or_etype)
     batch_size = len(batch_num_objs)
     length = max(max(F.asnumpy(batch_num_objs)), k)
     fill_val = -float('inf') if descending else float('inf')
-    feat_ = F.pad_packed_tensor(feat, batch_num_objs, fill_val, l_min=k)
-    if sortby is not None:
-        keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby+1), -1)
-        order = F.argsort(keys, -1, descending=descending)
-    else:
-        order = F.argsort(feat_, 1, descending=descending)
-    topk_indices = F.slice_axis(order, 1, 0, k)
-    # zero padding
-    feat_ = F.pad_packed_tensor(feat, batch_num_objs, 0, l_min=k)
-    if sortby is not None:
-        feat_ = F.reshape(feat_, (batch_size * length, -1))
-        shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
-        shift = F.copy_to(shift, F.context(feat))
-        topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
-    else:
-        feat_ = F.reshape(feat_, (-1,))
-        shift = F.repeat(F.arange(0, batch_size), k * hidden_size, -1) * length * hidden_size +\
-            F.cat([F.arange(0, hidden_size)] * batch_size * k, -1)
-        shift = F.copy_to(shift, F.context(feat))
-        topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
-
-    return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
-        topk_indices
+    feat_ = F.pad_packed_tensor(feat, batch_num_objs, fill_val, l_min=k)  # (batch_size, l, d)
+    if F.backend_name == 'pytorch' and sortby is not None:
+        # PyTorch's implementation of top-K
+        keys = feat_[..., sortby]  # (batch_size, l)
+        return _topk_torch(keys, k, descending, feat_)
+    else:
+        # Fall back to the framework-agnostic implementation of top-K
+        if sortby is not None:
+            keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby+1), -1)
+            order = F.argsort(keys, -1, descending=descending)
+        else:
+            order = F.argsort(feat_, 1, descending=descending)
+        topk_indices = F.slice_axis(order, 1, 0, k)
+        if sortby is not None:
+            feat_ = F.reshape(feat_, (batch_size * length, -1))
+            shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
+            shift = F.copy_to(shift, F.context(feat))
+            topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
+        else:
+            feat_ = F.reshape(feat_, (-1,))
+            shift = F.repeat(F.arange(0, batch_size), k * hidden_size, -1) * length * hidden_size +\
+                F.cat([F.arange(0, hidden_size)] * batch_size * k, -1)
+            shift = F.copy_to(shift, F.context(feat))
+            topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
+        out = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1))
+        out = F.replace_inf_with_zero(out)
+        return out, topk_indices


 def topk_nodes(graph, feat, k, *, descending=True, sortby=None, ntype=None):
     """Return a graph-level representation by a graph-wise top-k on
...
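Reviewer note: for reference, a hedged usage sketch of the public entry point whose signature appears above, assuming a DGL build that includes this patch; the batched toy graphs below are illustrative:

import dgl
import torch as th

g = dgl.batch([dgl.rand_graph(5, 10), dgl.rand_graph(3, 6)])
g.ndata['h'] = th.randn(g.num_nodes(), 4)

# Per graph, take the 2 nodes whose feature channel 0 is largest; on the
# PyTorch backend with sortby set, this now dispatches to _topk_torch.
feat, idx = dgl.topk_nodes(g, 'h', 2, sortby=0)
print(feat.shape)   # torch.Size([2, 2, 4])
print(idx.shape)    # torch.Size([2, 2])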