Commit 92082f98 authored by rusty1s
Browse files

faster coalesce if no value provided

parent 3c7253aa
#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src);
// Thin dispatch wrapper: validates that `src` lives on the GPU and then
// forwards to the CUDA implementation declared above.
std::tuple<at::Tensor, at::Tensor> unique(at::Tensor src) {
  CHECK_CUDA(src);
  auto result = unique_cuda(src);
  return result;
}
// Python bindings: exposes the `unique` wrapper as `unique_cuda.unique(src)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("unique", &unique, "Unique (CUDA)");
}
#include <ATen/ATen.h>
#define THREADS 1024
// Ceil-division launch helper; fully parenthesized so it stays correct when
// N is a compound expression or the macro itself is used inside one.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Marks the first element of every run of equal values in the sorted input.
//
// Preconditions: `src` is sorted ascending and `mask` is zero-initialized.
// Postcondition: mask[i] == 1 exactly when src[i] starts a new run.
// Uses a grid-stride loop, so any launch configuration covers all `numel`
// elements and a 1-block debug launch is still correct.
template <typename scalar_t>
__global__ void unique_cuda_kernel(const scalar_t *__restrict__ src,
                                   uint8_t *__restrict__ mask, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = index; i < numel; i += stride) {
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = 1;
    }
  }
}
// Returns the sorted unique values of `src` together with the index (into
// the original tensor) of one representative occurrence per value.
//
// Strategy: sort `src` (keeping the sort permutation `perm`), mark the head
// of every run of equal values with a byte mask, then compact both tensors
// with `masked_select`.
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  at::Tensor perm;
  std::tie(src, perm) = src.sort();

  auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));

  // Guard the launch: BLOCKS(0) == 0 is an invalid launch configuration,
  // so skip the kernel entirely for empty inputs (mask stays empty and the
  // masked_selects below correctly return empty tensors).
  if (src.numel() > 0) {
    AT_DISPATCH_ALL_TYPES(src.type(), "unique_cuda_kernel", [&] {
      unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
          src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
    });
  }

  // Keep only the first element of every run of equal (sorted) values.
  src = src.masked_select(mask);
  perm = perm.masked_select(mask);

  return std::make_tuple(src, perm);
}
......@@ -17,7 +17,9 @@ if torch.cuda.is_available():
'spspmm_cuda',
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
extra_link_args=['-lcusparse'],
)
),
CUDAExtension('unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu'])
]
cmdclass['build_ext'] = BuildExtension
......
import torch
import torch_scatter
from .utils.unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0):
    """Row-wise sorts and coalesces a sparse matrix given by ``(index, value)``.

    Duplicate ``(row, col)`` entries are merged by the scatter reduction
    :obj:`op` (e.g. summed for ``op='add'``).  When :obj:`value` is
    ``None``, only the index tensor is deduplicated via the faster custom
    ``unique`` kernel.

    Args:
        index (LongTensor): ``2 x E`` tensor of (row, col) coordinates.
        value (Tensor or None): entry values, or ``None`` for index-only
            coalescing.
        m (int): number of rows of the sparse matrix.
        n (int): number of columns of the sparse matrix.
        op (string): scatter reduction used to merge duplicate values.
        fill_value (int): fill value forwarded to the scatter op.

    Returns:
        (LongTensor, Tensor or None): the coalesced index and value.
    """
    row, col = index

    if value is None:
        # Fast path: no values to merge, so finding one representative
        # position per linearized (row, col) coordinate is sufficient.
        _, perm = unique(row * n + col)
        index = torch.stack([row[perm], col[perm]], dim=0)
        return index, value

    uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)

    # Map each unique linearized coordinate back to one of its source
    # positions (for duplicates, the last scattered position wins).
    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
    perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)
    index = torch.stack([row[perm], col[perm]], dim=0)

    # Merge duplicate values with the requested scatter reduction; some
    # torch_scatter ops return (out, argout) tuples — keep the values only.
    op = getattr(torch_scatter, 'scatter_{}'.format(op))
    value = op(value, inv, 0, None, perm.size(0), fill_value)
    if isinstance(value, tuple):
        value = value[0]

    return index, value
import torch
import numpy as np
if torch.cuda.is_available():
import unique_cuda
def unique(src):
    """Return the sorted unique elements of ``src`` and, for each unique
    element, an index in the flattened input where it occurs.

    Dispatches to the custom CUDA extension for GPU tensors and to
    :func:`numpy.unique` for CPU tensors.
    """
    flat = src.contiguous().view(-1)
    if flat.is_cuda:
        return unique_cuda.unique(flat)
    values, first_idx = np.unique(flat.numpy(), return_index=True)
    return torch.from_numpy(values), torch.from_numpy(first_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment