"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "8257aff511c7a3161e58e4aa0fa0d385037c7161"
Unverified Commit cf8a3fb3 authored by Zihao Ye, committed by GitHub

[hotfix] Address the performance issue of topk. (#2628)



* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent e46983f6
@@ -227,32 +227,40 @@ def randint(shape, dtype, ctx, low, high):
 def pad_packed_tensor(input, lengths, value, l_min=None):
     old_shape = input.shape
-    if isinstance(lengths, th.Tensor):
-        max_len = as_scalar(lengths.max())
+    device = input.device
+    if not is_tensor(lengths):
+        lengths = th.tensor(lengths, dtype=th.int64, device=device)
     else:
-        max_len = builtins.max(lengths)
+        lengths = lengths.to(device)
+    max_len = as_scalar(lengths.max())
     if l_min is not None:
         max_len = builtins.max(max_len, l_min)
     batch_size = len(lengths)
-    device = input.device
     x = input.new(batch_size * max_len, *old_shape[1:])
     x.fill_(value)
-    index = []
-    for i, l in enumerate(lengths):
-        index.extend(range(i * max_len, i * max_len + l))
-    index = th.tensor(index).to(device)
-    return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
+    index = th.ones(len(input), dtype=th.int64, device=device)
+    cum_lengths = th.cumsum(lengths, 0)
+    index[cum_lengths[:-1]] += (max_len - lengths[:-1])
+    index = th.cumsum(index, 0) - 1
+    x[index] = input
+    return x.view(batch_size, max_len, *old_shape[1:])
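
The rewrite replaces the Python loop that assembled `index` element by element with two `cumsum` calls: every packed row advances the write position by one slot, and at each sequence boundary the position additionally skips the padded tail of the previous sequence. A minimal standalone sketch of the same trick (`demo_pad` is an illustrative name, not part of DGL):

    import torch as th

    def demo_pad(packed, lengths, value=0.0):
        # packed: (sum(lengths), d); lengths: 1-D int64 tensor of sequence lengths
        max_len = int(lengths.max())
        batch_size = lengths.numel()
        out = packed.new_full((batch_size * max_len, packed.shape[1]), value)
        # One destination slot per packed row; at each sequence boundary,
        # also skip the padded tail of the previous sequence.
        index = th.ones(packed.shape[0], dtype=th.int64)
        cum_lengths = th.cumsum(lengths, 0)
        index[cum_lengths[:-1]] += max_len - lengths[:-1]
        index = th.cumsum(index, 0) - 1
        out[index] = packed
        return out.view(batch_size, max_len, -1)

    packed = th.arange(10, dtype=th.float32).view(5, 2)  # rows for lengths [2, 3]
    padded = demo_pad(packed, th.tensor([2, 3]))         # shape (2, 3, 2)

For lengths [2, 3] and max_len 3, the index tensor becomes [0, 1, 3, 4, 5]: row 2 of the flat buffer is left as padding, exactly what the old loop produced, but in O(1) kernel launches instead of a Python-level loop over every element.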
 def pack_padded_tensor(input, lengths):
-    batch_size, max_len = input.shape[:2]
+    max_len = input.shape[1]
     device = input.device
-    index = []
-    for i, l in enumerate(lengths):
-        index.extend(range(i * max_len, i * max_len + l))
-    index = th.tensor(index).to(device)
-    return gather_row(input.view(batch_size * max_len, -1), index)
+    if not is_tensor(lengths):
+        lengths = th.tensor(lengths, dtype=th.int64, device=device)
+    else:
+        lengths = lengths.to(device)
+    input = input.view(-1, *input.shape[2:])
+    out_len = lengths.sum().item()
+    index = th.ones(out_len, dtype=th.int64, device=device)
+    cum_lengths = th.cumsum(lengths, 0)
+    index[cum_lengths[:-1]] += (max_len - lengths[:-1])
+    index = th.cumsum(index, 0) - 1
+    return input[index]
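
`pack_padded_tensor` is the inverse gather and reuses the same index construction, now sized by the total number of valid rows rather than the padded buffer. Continuing the sketch above (`demo_pack` is likewise an illustrative name):

    def demo_pack(padded, lengths):
        # padded: (batch, max_len, d) -> (sum(lengths), d)
        max_len = padded.shape[1]
        flat = padded.reshape(-1, padded.shape[-1])
        index = th.ones(int(lengths.sum()), dtype=th.int64)
        cum_lengths = th.cumsum(lengths, 0)
        index[cum_lengths[:-1]] += max_len - lengths[:-1]
        index = th.cumsum(index, 0) - 1
        return flat[index]

    assert th.equal(demo_pack(padded, th.tensor([2, 3])), packed)  # round trip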
 def boolean_mask(input, mask):
     if 'bool' not in str(mask.dtype):
...
@@ -485,6 +485,38 @@ READOUT_ON_ATTRS = {
     'edges': ('edata', 'batch_num_edges', 'number_of_edges'),
 }
 
+def _topk_torch(keys, k, descending, x):
+    """Internal function to take graph-wise top-k node/edge features according
+    to the rank given by keys. This function is PyTorch-only.
+
+    Parameters
+    ----------
+    keys : Tensor
+        The key for ranking.
+    k : int
+        The :math:`k` in "top-:math:`k`".
+    descending : bool
+        Indicates whether to return the features corresponding to the largest
+        or the smallest elements.
+    x : Tensor
+        The padded feature with shape (batch, max_len, *).
+
+    Returns
+    -------
+    sorted_feat : Tensor
+        A tensor with shape :math:`(batch, k, *)`.
+    sorted_idx : Tensor
+        A tensor with shape :math:`(batch, k)`.
+    """
+    import torch as th
+    batch_size, max_len = x.shape[0], x.shape[1]
+    topk_indices = keys.topk(k, -1, largest=descending)[1]  # (batch_size, k)
+    x = x.view((batch_size * max_len), -1)
+    shift = th.arange(0, batch_size, device=x.device).view(batch_size, 1) * max_len
+    topk_indices_ = topk_indices + shift
+    x = x[topk_indices_].view(batch_size, k, -1)
+    return th.masked_fill(x, th.isinf(x), 0), topk_indices
+
+
 def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
     """Internal function to take graph-wise top-k node/edge features of
     field :attr:`feat` in :attr:`graph` ranked by keys at given
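
On the PyTorch backend the ranking thus reduces to one batched `topk` plus a single advanced-indexing gather; the final `masked_fill` zeroes out the ±inf padding that gets picked up whenever a graph has fewer than k nodes. A toy walk-through under the same shape conventions (all names below are illustrative, not DGL API):

    import torch as th

    batch_size, max_len, d, k = 2, 4, 3, 2
    x = th.randn(batch_size, max_len, d)
    x[0, 3:] = float('-inf')              # padding written by pad_packed_tensor
    keys = x[..., 0]                      # rank by feature channel 0 (sortby=0)

    topk_idx = keys.topk(k, -1, largest=True)[1]                 # (batch, k)
    shift = th.arange(batch_size).view(batch_size, 1) * max_len  # per-graph row offset
    out = x.view(batch_size * max_len, -1)[topk_idx + shift]     # (batch, k, d)
    out = th.masked_fill(out, th.isinf(out), 0)                  # zero the padding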
...
@@ -534,41 +566,41 @@ def _topk_on(graph, typestr, feat, k, descending, sortby, ntype_or_etype):
     if F.ndim(data[feat]) > 2:
         raise DGLError('Only support {} feature `{}` with dimension less than or'
                        ' equal to 2'.format(typestr, feat))
     feat = data[feat]
     hidden_size = F.shape(feat)[-1]
     batch_num_objs = getattr(graph, batch_num_objs_attr)(ntype_or_etype)
     batch_size = len(batch_num_objs)
     length = max(max(F.asnumpy(batch_num_objs)), k)
     fill_val = -float('inf') if descending else float('inf')
-    feat_ = F.pad_packed_tensor(feat, batch_num_objs, fill_val, l_min=k)
-
-    if sortby is not None:
-        keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby+1), -1)
-        order = F.argsort(keys, -1, descending=descending)
-    else:
-        order = F.argsort(feat_, 1, descending=descending)
-    topk_indices = F.slice_axis(order, 1, 0, k)
-
-    # zero padding
-    feat_ = F.pad_packed_tensor(feat, batch_num_objs, 0, l_min=k)
-
-    if sortby is not None:
-        feat_ = F.reshape(feat_, (batch_size * length, -1))
-        shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
-        shift = F.copy_to(shift, F.context(feat))
-        topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
+    feat_ = F.pad_packed_tensor(feat, batch_num_objs, fill_val, l_min=k)  # (batch_size, l, d)
+
+    if F.backend_name == 'pytorch' and sortby is not None:
+        # PyTorch's implementation of top-K
+        keys = feat_[..., sortby]  # (batch_size, l)
+        return _topk_torch(keys, k, descending, feat_)
     else:
-        feat_ = F.reshape(feat_, (-1,))
-        shift = F.repeat(F.arange(0, batch_size), k * hidden_size, -1) * length * hidden_size +\
-            F.cat([F.arange(0, hidden_size)] * batch_size * k, -1)
-        shift = F.copy_to(shift, F.context(feat))
-        topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
-
-    return F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1)),\
-        topk_indices
+        # Fallback to framework-agnostic implementation of top-K
+        if sortby is not None:
+            keys = F.squeeze(F.slice_axis(feat_, -1, sortby, sortby+1), -1)
+            order = F.argsort(keys, -1, descending=descending)
+        else:
+            order = F.argsort(feat_, 1, descending=descending)
+        topk_indices = F.slice_axis(order, 1, 0, k)
+        if sortby is not None:
+            feat_ = F.reshape(feat_, (batch_size * length, -1))
+            shift = F.repeat(F.arange(0, batch_size) * length, k, -1)
+            shift = F.copy_to(shift, F.context(feat))
+            topk_indices_ = F.reshape(topk_indices, (-1,)) + shift
+        else:
+            feat_ = F.reshape(feat_, (-1,))
+            shift = F.repeat(F.arange(0, batch_size), k * hidden_size, -1) * length * hidden_size +\
+                F.cat([F.arange(0, hidden_size)] * batch_size * k, -1)
+            shift = F.copy_to(shift, F.context(feat))
+            topk_indices_ = F.reshape(topk_indices, (-1,)) * hidden_size + shift
+        out = F.reshape(F.gather_row(feat_, topk_indices_), (batch_size, k, -1))
+        out = F.replace_inf_with_zero(out)
+        return out, topk_indices
 def topk_nodes(graph, feat, k, *, descending=True, sortby=None, ntype=None):
     """Return a graph-level representation by a graph-wise top-k on
...
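
At the user level nothing changes: the same readout call now takes the `_topk_torch` fast path whenever the backend is PyTorch and `sortby` is given, and falls back to the `F.argsort`-based path otherwise. A hedged usage sketch with a toy graph (values are arbitrary):

    import dgl
    import torch as th

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # toy 3-node graph
    g.ndata['h'] = th.randn(3, 4)
    # sortby=0 ranks nodes by feature channel 0; on the PyTorch backend
    # this dispatches to the new fast path.
    feat, idx = dgl.topk_nodes(g, 'h', 2, sortby=0)
    print(feat.shape)  # (1, 2, 4): one graph in the batch, k=2, hidden size 4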