Unverified Commit f7e29388 authored by Matthias Fey, committed by GitHub

Merge pull request #90 from rusty1s/segment

[WIP] segment_* operators
parents 4ceb2d1a d1dd9466
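In short, this merge adds CUDA `segment_coo`/`segment_csr` reductions and their `gather_coo`/`gather_csr` counterparts to torch_scatter. As a rough orientation, a minimal usage sketch of the new API as it is exercised by the benchmark and test files below (the new ops are exercised on CUDA only in this commit; shapes are illustrative):

import torch
from torch_scatter import segment_coo, segment_csr, gather_coo, gather_csr

src = torch.randn(6, 16, device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')   # sorted COO row indices
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')     # CSR row pointer

out = segment_coo(src, index, dim_size=4, reduce='add')   # shape [4, 16]
out = segment_csr(src, indptr, reduce='mean')             # shape [4, 16]
out, arg_out = segment_csr(src, indptr, reduce='max')     # 'min'/'max' also return arg indices

exp = gather_coo(out, index)                              # expand back to shape [6, 16]
exp = gather_csr(out, indptr)

The files below are, in order: the gather and segment benchmarks, the CUDA extension sources, the updated setup.py and tests, and the Python autograd wrappers.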
import time
import itertools
import argparse
import torch
from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr
from scatter_segment import iters, sizes
from scatter_segment import short_rows, long_rows, download, bold
@torch.no_grad()
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes[1:]:
try:
x = torch.randn((dim_size, size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = x.index_select(0, row)
out2 = gather_coo(x, row)
out3 = gather_csr(x, rowptr)
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError:
torch.cuda.empty_cache()
def time_func(func, x):
try:
torch.cuda.synchronize()
t = time.perf_counter()
if not args.with_backward:
with torch.no_grad():
for _ in range(iters):
func(x)
else:
x = x.requires_grad_()
for _ in range(iters):
out = func(x)
torch.autograd.grad(out, x, out, only_inputs=True)
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError:
torch.cuda.empty_cache()
return float('inf')
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def select(x):
return x.index_select(0, row)
def gather(x):
return x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
def gat_coo(x):
return gather_coo(x, row)
def gat_csr(x):
return gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], []
for size in sizes:
try:
x = torch.randn((dim_size, size), device=args.device)
t1 += [time_func(select, x)]
t2 += [time_func(gather, x)]
t3 += [time_func(gat_coo, x)]
t4 += [time_func(gat_csr, x)]
del x
except RuntimeError:
torch.cuda.empty_cache()
for t in (t1, t2, t3, t4):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4])
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist()
name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SELECT ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('GAT ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('GAT_COO')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print('\t'.join([bold('GAT_CSR')] +
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
timing(dataset)
import time
import os.path as osp
import itertools
import argparse
import wget
import torch
from scipy.io import loadmat
import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
iters = 20
sizes = [1, 16, 32, 64, 128, 256, 512]
short_rows = [
('DIMACS10', 'citationCiteseer'),
('SNAP', 'web-Stanford'),
]
long_rows = [
('Janna', 'StocF-1465'),
('GHS_psdef', 'ldoor'),
]
def download(dataset):
url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
for group, name in itertools.chain(long_rows, short_rows):
if not osp.exists(f'{name}.mat'):
print(f'Downloading {group}/{name}:')
wget.download(url.format(group, name))
print('')
def bold(text, flag=True):
return f'\033[1m{text}\033[0m' if flag else text
@torch.no_grad()
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
out3 = segment_csr(x, rowptr, reduce='mean')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_().mul_(-1)
out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='min')
out3, _ = segment_csr(x, rowptr, reduce='min')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_()
out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError:
torch.cuda.empty_cache()
def time_func(func, x):
try:
torch.cuda.synchronize()
t = time.perf_counter()
if not args.with_backward:
with torch.no_grad():
for _ in range(iters):
func(x)
else:
x = x.requires_grad_()
for _ in range(iters):
out = func(x)
out = out[0] if isinstance(out, tuple) else out
torch.autograd.grad(out, x, out, only_inputs=True)
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError:
torch.cuda.empty_cache()
return float('inf')
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row_perm = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}')
return op(x, row, dim=0, dim_size=dim_size)
def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.reduce}')
return op(x, row_perm, dim=0, dim_size=dim_size)
def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce)
def seg_csr(x):
return segment_csr(x, rowptr, reduce=args.reduce)
def dense1(x):
return getattr(torch, args.dense_reduce)(x, dim=-2)
def dense2(x):
return getattr(torch, args.dense_reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
t1 += [time_func(sca_row, x)]
t2 += [time_func(sca_col, x)]
t3 += [time_func(seg_coo, x)]
t4 += [time_func(seg_csr, x)]
del x
except RuntimeError:
torch.cuda.empty_cache()
for t in (t1, t2, t3, t4):
t.append(float('inf'))
try:
x = torch.randn((dim_size, int(avg_row_len + 1), size),
device=args.device)
t5 += [time_func(dense1, x)]
x = x.view(dim_size, size, int(avg_row_len + 1))
t6 += [time_func(dense2, x)]
del x
except RuntimeError:
torch.cuda.empty_cache()
for t in (t5, t6):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4, t5, t6])
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist()
name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SCA_ROW')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('SCA_COL')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('SEG_COO')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print('\t'.join([bold('SEG_CSR')] +
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
print('\t'.join([bold('DENSE1 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])]))
print('\t'.join([bold('DENSE2 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])]))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
timing(dataset)
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt);
at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
at::optional<at::Tensor> out_opt);
at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_csr_cuda(src, indptr, out_opt);
}
at::Tensor gather_coo(at::Tensor src, at::Tensor index,
at::optional<at::Tensor> out_opt) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return gather_coo_cuda(src, index, out_opt);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
m.def("gather_coo", &gather_coo, "Gather COO (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#include "indptr.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
template <typename scalar_t, int TB>
__global__ void
gather_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx % TB;
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = __ldg(src_data + row_idx);
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) {
out_data[offset + out_idx] = val; // "Mostly" coalesced.
}
}
}
template <typename scalar_t>
__global__ void gather_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = src_data[thread_idx]; // Coalesced.
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int out_idx = row_start; out_idx < row_end; out_idx++) {
out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced.
}
}
}
at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
at::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else {
auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto sizes = src.sizes().vec();
sizes[gather_dim] = *h_gather_size;
out = at::empty(sizes, src.options());
}
auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(gather_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
if (K == 1) {
gather_csr_kernel<scalar_t, 4><<<BLOCKS(1, 4 * N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
} else {
gather_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
}
});
return out;
}
template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
scalar_t val = __ldg(src_data + offset + row);
out_data[row_idx] = val;
}
}
template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
if (thread_idx < E * K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int row = index_info.data[offset];
offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
scalar_t val = __ldg(src_data + offset + K * row + col_idx);
out_data[thread_idx] = val;
}
}
at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = index.dim() - 1;
at::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim);
out = at::empty(sizes, src.options());
}
auto E = index.numel();
auto K = out.numel() / E;
auto N = src.size(gather_dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
if (K == 1) {
gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, N);
} else {
gather_coo_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, K, N);
}
});
return out;
}
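For reference, a compact plain-PyTorch sketch of what the two gather kernels above compute in the one-dimensional case (illustrative only; the CUDA code additionally handles batched index/indptr tensors and trailing feature dimensions):

import torch

def gather_coo_ref(src, index):
    # out[e] = src[index[e]], i.e. an index_select along dim 0, which is
    # exactly the equivalence the benchmark's correctness() check asserts.
    return src.index_select(0, index)

def gather_csr_ref(src, indptr):
    # Repeat src[i] exactly (indptr[i + 1] - indptr[i]) times, expanding each
    # CSR row value over its segment.
    counts = indptr[1:] - indptr[:-1]
    return src.repeat_interleave(counts, dim=0)

src = torch.tensor([1, 2, 3, 4])
index = torch.tensor([0, 0, 1, 1, 1, 3])
indptr = torch.tensor([0, 2, 5, 5, 6])
assert gather_coo_ref(src, index).tolist() == [1, 1, 2, 2, 2, 4]
assert gather_csr_ref(src, indptr).tolist() == [1, 1, 2, 2, 2, 4]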
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
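To make the row-pointer indexing concrete, here is a small plain-Python mirror of `IndexPtrToOffset::get` (illustrative only; names are ad hoc). The last `indptr` dimension is treated as having one entry fewer than its size, so `idx` enumerates rows rather than pointer entries and the trailing pointer of each row-pointer vector is never addressed directly:

def indptr_to_offset(idx, sizes, strides):
    # Mirror of IndexPtrToOffset::get for an indptr TensorInfo with the given
    # sizes and strides.
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for i in range(len(sizes) - 2, -1, -1):
        offset += (idx % sizes[i]) * strides[i]
        idx //= sizes[i]
    return offset

# For a contiguous indptr of shape (2, 5) (two row-pointer vectors with four
# rows each), row 5 maps to the flat element indptr[1, 1]; its successor
# indptr[1, 2] is then read via the last-dimension stride.
assert indptr_to_offset(5, (2, 5), (5, 1)) == 6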
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce);
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce);
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return segment_csr_cuda(src, indptr, out_opt, reduce);
}
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
return segment_coo_cuda(src, index, out, reduce);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("segment_csr", &segment_csr, "Segment CSR (CUDA)");
m.def("segment_coo", &segment_coo, "Segment COO (CUDA)");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "atomics.cuh"
#include "compat.cuh"
#include "indptr.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
if (reduce == "add") { \
const ReductionType REDUCE = ADD; \
return __VA_ARGS__(); \
} else if (reduce == "mean") { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} else if (reduce == "min") { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} else if (reduce == "max") { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == ADD) {
*address = val;
} else if (REDUCE == MEAN) {
*address = val / (scalar_t)max(count, 1);
} else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else {
*address = (scalar_t)0;
}
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == ADD || REDUCE == MEAN) {
atomAdd(address, val);
} else if (REDUCE == MIN && val < *address) {
atomMin(address, val);
} else if (REDUCE == MAX && val > *address) {
atomMax(address, val);
}
}
};
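// A note on the write semantics above: for "min"/"max" reductions, an empty
// segment (count == 0) stores 0 and leaves the argument buffer untouched; the
// host code below pre-fills that buffer with src.size(reduce_dim), which is
// why empty segments show up in the tests as value 0 with an out-of-range arg
// index. For "mean", the accumulated sum is divided by max(count, 1), so empty
// segments likewise yield 0 instead of NaN.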
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void
segment_csr_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N,
size_t E) {
// Each warp processes exactly `32/TB` rows and aggregates all row values
// via a parallel reduction.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
int lane_idx = thread_idx & (TB - 1);
if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx);
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
arg_out_data + row_idx, arg,
row_end - row_start);
}
}
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void segment_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row for a single feature column. This
// turned out to be more efficient than using shared memory since it avoids
// synchronization barriers.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int lane_idx = thread_idx % K;
if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int64_t row_start = __ldg(indptr_info.data + offset);
int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
}
Reducer<scalar_t, REDUCE>::write(out_data + thread_idx, val,
arg_out_data + thread_idx, arg,
row_end - row_start);
}
}
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce) {
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
// Broadcast `indptr` across the leading dimensions of `src` via `expand`.
auto sizes = indptr.sizes().vec();
for (int i = 0; i < indptr.dim() - 1; i++) {
sizes[i] = src.size(i);
}
indptr = indptr.expand(sizes);
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
at::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
"Input mismatch");
} else {
sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
out = at::empty(sizes, src.options());
}
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_csr_kernel<scalar_t, REDUCE, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, E);
} else {
segment_csr_broadcast_kernel<scalar_t, REDUCE>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t N) {
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
// result via atomics.
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = row_idx & (32 - 1);
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / D) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
if (idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp);
}
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
idx != next_idx)
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int64_t idx = index_info.data[offset];
int out_idx = (row_idx / D) * N + idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[row_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K, size_t N) {
// Each thread processes a single column and `TB` index entries. Coalesced
// reads and writes are performed in column-major order. The intermediate
// results are written via atomics.
int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D;
int E_2 = D + TB - (D % TB);
int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
int dim_start = (row_idx * TB) / E_2;
int row_start = (row_idx * TB) % E_2;
if (dim_start < E_1 && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
dim_start * D + row_start, index_info);
int idx1 = __ldg(index_info.data + offset), idx2;
scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];
#pragma unroll
for (int i = 1; i < TB; i++) {
if (row_start + i >= D)
break;
idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
assert(idx1 <= idx2);
if (idx1 == idx2) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
} else {
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
val = src_data[K * (dim_start * D + row_start + i) + col_idx];
}
idx1 = idx2;
}
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (dim_start * N + idx1) * K + col_idx, val);
}
}
template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = __ldg(index_info.data + offset);
int out_idx = ((row_idx / D) * N + idx) * K + col_idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[thread_idx] == val)
arg_out_data[out_idx] = row_idx % D;
}
}
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) {
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
// Broadcasting across `index` via `expand`.
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
src = src.contiguous();
out = out.contiguous();
auto reduce_dim = index.dim() - 1;
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce == "min" || reduce == "max") {
arg_out = at::full_like(out, src.size(reduce_dim), index.options());
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
auto E = index.numel();
auto E_2 = index.size(reduce_dim);
auto E_1 = index.numel() / E_2;
auto K = src.numel() / E;
auto N = out.size(reduce_dim);
auto avg_len = (float)E_2 / (float)N;
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
segment_coo_kernel<scalar_t, REDUCE, true>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
out_data, E, N);
} else if (avg_len <= 8) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
<<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else if (avg_len <= 16) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
<<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else if (avg_len <= 32) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
<<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
<<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
}
if (REDUCE == MIN || REDUCE == MAX) {
if (K == 1) {
segment_coo_arg_kernel<scalar_t>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, N);
} else {
segment_coo_arg_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N);
}
}
});
});
if (reduce == "mean") {
auto sizes = index.sizes().vec();
sizes[reduce_dim] = out.size(reduce_dim);
auto count = at::zeros(sizes, out.options());
AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
auto count_data = count.DATA_PTR<scalar_t>();
segment_coo_kernel<scalar_t, ADD, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N);
});
count.clamp_(1);
arg_out = count;
for (int i = reduce_dim + 1; i < out.dim(); i++) {
count = count.unsqueeze(-1);
}
out.div_(count);
}
return std::make_tuple(out, arg_out);
}
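Analogously, a plain-PyTorch sketch of the reduction the segment kernels implement in the one-dimensional case (illustrative only; the kernels additionally handle batched pointers, trailing feature dimensions, and the min/max argument outputs):

import torch

def segment_csr_ref(src, indptr, reduce='add'):
    # Reduce each slice src[indptr[i]:indptr[i + 1]] into out[i]; empty
    # segments yield 0, matching Reducer::write above.
    out = []
    for i in range(indptr.numel() - 1):
        seg = src[indptr[i]:indptr[i + 1]]
        if seg.numel() == 0:
            out.append(torch.zeros((), dtype=src.dtype))
        else:
            out.append(getattr(seg, 'sum' if reduce == 'add' else reduce)())
    return torch.stack(out)

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 5, 6])
print(segment_csr_ref(src, indptr, 'add'))   # tensor([ 3., 12.,  0.,  6.])
print(segment_csr_ref(src, indptr, 'mean'))  # tensor([1.5000, 4.0000, 0.0000, 6.0000])

# segment_coo computes the same reduction from the sorted COO form of the same
# pointer (index = [0, 0, 1, 1, 1, 3]), accumulating entries that share an index.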
import platform
import os.path as osp
from glob import glob
from setuptools import setup, find_packages
from sys import argv
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
USE_GPU = True
if '--cpu' in argv:
USE_GPU = False
cxx_extra_compile_args = []
nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
if platform.system() != 'Windows':
    cxx_extra_compile_args += ['-Wno-unused-variable']
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
ext_modules = []
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
ext_modules += [
    CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
                 extra_compile_args=cxx_extra_compile_args) for ext in exts
]
if CUDA_HOME is not None and USE_GPU:
    exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
    ext_modules += [
        CUDAExtension(
            f'torch_scatter.{ext}_cuda',
            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
                'cxx': cxx_extra_compile_args,
                'nvcc': nvcc_extra_compile_args,
            }) for ext in exts
    ]
__version__ = '1.5.0'
url = 'https://github.com/rusty1s/pytorch_scatter'
install_requires = []
@@ -47,10 +55,7 @@ setup(
author_email='matthias.fey@tu-dortmund.de',
url=url,
download_url='{}/archive/{}.tar.gz'.format(url, __version__),
keywords=['pytorch', 'scatter', 'segment'],
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
......
@@ -22,11 +22,21 @@ def test_softmax(dtype, device):
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device)
assert torch.allclose(out, expected)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax_broadcasting(dtype, device):
src = torch.randn(10, 5, dtype=dtype, device=device)
index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
out = out.sum(dim=1)
assert torch.allclose(out, torch.ones_like(out))
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
@@ -42,6 +52,6 @@ def test_log_softmax(dtype, device):
expected = torch.stack([
out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
], dim=0).to(device)
assert torch.allclose(out, expected)
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import gather_coo, gather_csr
from .utils import tensor
dtypes = [torch.float]
devices = [torch.device('cuda')]
tests = [
{
'src': [1, 2, 3, 4],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'expected': [1, 1, 2, 2, 2, 4],
},
{
'src': [[1, 2], [3, 4], [5, 6], [7, 8]],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'expected': [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4], [7, 8]]
},
{
'src': [[1, 3, 5, 7], [2, 4, 6, 8]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'expected': [[1, 1, 3, 3, 3, 7], [2, 2, 2, 4, 4, 6]],
},
{
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'expected': [[[1, 2], [1, 2], [3, 4]], [[7, 9], [12, 13], [12, 13]]],
},
{
'src': [[1], [2]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'expected': [[1, 1], [2, 2]],
},
{
'src': [[[1, 1]], [[2, 2]]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'expected': [[[1, 1], [1, 1]], [[2, 2], [2, 2]]],
},
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
out = gather_coo(src, index)
assert torch.all(out == expected)
out = gather_csr(src, indptr)
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,device', product(tests, devices))
def test_backward(test, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(gather_coo, (src, index, None)) is True
assert gradcheck(gather_csr, (src, indptr, None)) is True
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_segment_out(test, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
size = list(src.size())
size[index.dim() - 1] = index.size(-1)
out = src.new_full(size, -2)
gather_coo(src, index, out)
assert torch.all(out == expected)
out.fill_(-2)
gather_csr(src, indptr, out)
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous_segment(test, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
if src.dim() > 1:
src = src.transpose(0, 1).contiguous().transpose(0, 1)
if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1)
if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = gather_coo(src, index)
assert torch.all(out == expected)
out = gather_csr(src, indptr)
assert torch.all(out == expected)
@@ -20,5 +20,5 @@ def test_logsumexp(dtype, device):
out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
out4 = torch.tensor(-1, dtype=dtype)
expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device)
assert torch.allclose(out, expected)
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import segment_coo, segment_csr
from .utils import tensor, dtypes
reductions = ['add', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean']
devices = [torch.device('cuda')]
tests = [
{
'src': [1, 2, 3, 4, 5, 6],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 2, 6, 5],
'max': [2, 5, 0, 6],
'arg_max': [1, 4, 6, 5],
},
{
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
'arg_max': [[1, 1], [4, 4], [6, 6], [5, 5]],
},
{
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]],
},
{
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
'arg_max': [[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]],
},
{
'src': [[1, 3], [2, 4]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[4], [6]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
'max': [[3], [4]],
'arg_max': [[1], [1]],
},
{
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
'max': [[[3, 3]], [[4, 4]]],
'arg_max': [[[1, 1]], [[1, 1]]],
},
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
out = segment_coo(src, index, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,device',
product(tests, grad_reductions, devices))
def test_backward(test, reduce, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(segment_coo, (src, index, None, None, reduce)) is True
assert gradcheck(segment_csr, (src, indptr, None, reduce)) is True
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_segment_out(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
size = list(src.size())
size[indptr.dim() - 1] = indptr.size(-1) - 1
out = src.new_full(size, -2)
segment_csr(src, indptr, out, reduce=reduce)
assert torch.all(out == expected)
out.fill_(-2)
segment_coo(src, index, out, reduce=reduce)
if reduce == 'add':
expected = expected - 2
elif reduce == 'mean':
expected = out # We can not really test this here.
elif reduce == 'min':
expected = expected.fill_(-2)
elif reduce == 'max':
expected[expected == 0] = -2
else:
raise ValueError
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_non_contiguous_segment(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
if src.dim() > 1:
src = src.transpose(0, 1).contiguous().transpose(0, 1)
if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1)
if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = segment_coo(src, index, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
import torch
dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
devices = [torch.device('cpu')]
......
@@ -7,6 +7,10 @@ from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
from .segment import segment_coo, segment_csr
from .gather import gather_coo, gather_csr
import torch_scatter.composite
__version__ = '1.4.0'
@@ -21,6 +25,10 @@ __all__ = [
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'segment_coo',
'segment_csr',
'gather_coo',
'gather_csr',
'torch_scatter',
'__version__',
]
import torch
from torch_scatter import scatter_add, scatter_max
from torch_scatter.utils.gen import broadcast
def scatter_softmax(src, index, dim=-1, eps=1e-12):
@@ -31,6 +32,7 @@ def scatter_softmax(src, index, dim=-1, eps=1e-12):
raise ValueError('`scatter_softmax` can only be computed over tensors '
'with floating point data types.')
src, index = broadcast(src, index, dim)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
@@ -73,6 +75,7 @@ def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
raise ValueError('`scatter_log_softmax` can only be computed over '
'tensors with floating point data types.')
src, index = broadcast(src, index, dim)
max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
max_per_src_element = max_value_per_index.gather(dim, index)
......
import torch
if torch.cuda.is_available():
from torch_scatter import gather_cuda, segment_cuda
class GatherCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out):
if out is not None:
ctx.mark_dirty(out)
ctx.src_size = list(src.size())
ctx.save_for_backward(index)
return gather_cuda.gather_coo(src, index, out)
@staticmethod
def backward(ctx, grad_out):
(index, ), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
grad_src, _ = segment_cuda.segment_coo(
grad_out, index, grad_out.new_zeros(src_size), 'add')
return grad_src, None, None
class GatherCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out):
if out is not None:
ctx.mark_dirty(out)
ctx.src_size = list(src.size())
ctx.save_for_backward(indptr)
return gather_cuda.gather_csr(src, indptr, out)
@staticmethod
def backward(ctx, grad_out):
(indptr, ), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
grad_src, _ = segment_cuda.segment_csr(
grad_out, indptr, grad_out.new_empty(src_size), 'add')
return grad_src, None, None
def gather_coo(src, index, out=None):
return GatherCOO.apply(src, index, out)
def gather_csr(src, indptr, out=None):
return GatherCSR.apply(src, indptr, out)
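A short usage sketch for the gather wrappers above, using the first test case from test_gather.py (requires the CUDA build, since the forward dispatches to gather_cuda):

import torch
from torch_scatter import gather_coo, gather_csr

src = torch.tensor([1., 2., 3., 4.], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')

out = gather_coo(src, index)   # tensor([1., 1., 2., 2., 2., 4.])
out = gather_csr(src, indptr)  # same result, driven by the row pointer

# An optional `out` tensor is filled in place, as exercised by test_segment_out.
out = src.new_empty(6)
gather_csr(src, indptr, out)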
import torch
def min_value(dtype):
try:
return torch.finfo(dtype).min
except TypeError:
return torch.iinfo(dtype).min
def max_value(dtype):
try:
return torch.finfo(dtype).max
except TypeError:
return torch.iinfo(dtype).max
import torch
from torch_scatter.helpers import min_value, max_value
if torch.cuda.is_available():
from torch_scatter import segment_cuda, gather_cuda
class SegmentCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out, dim_size, reduce):
assert reduce in ['add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
ctx.src_size = list(src.size())
fill_value = 0
if out is None:
dim_size = index.max().item() + 1 if dim_size is None else dim_size
size = list(src.size())
size[index.dim() - 1] = dim_size
if reduce == 'min':
fill_value = max_value(src.dtype)
elif reduce == 'max':
fill_value = min_value(src.dtype)
out = src.new_full(size, fill_value)
out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
if fill_value != 0:
out.masked_fill_(out == fill_value, 0)
ctx.save_for_backward(index, arg_out)
if reduce == 'min' or reduce == 'max':
return out, arg_out
else:
return out
@staticmethod
def backward(ctx, grad_out, *args):
(index, arg_out), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'add':
grad_src = gather_cuda.gather_coo(grad_out, index,
grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
grad_src = gather_cuda.gather_coo(grad_out, index,
grad_out.new_empty(src_size))
count = arg_out
count = gather_cuda.gather_coo(
count, index, count.new_empty(src_size[:index.dim()]))
for _ in range(grad_out.dim() - index.dim()):
count = count.unsqueeze(-1)
grad_src.div_(count)
elif ctx.reduce == 'min' or ctx.reduce == 'max':
src_size[index.dim() - 1] += 1
grad_src = grad_out.new_zeros(src_size).scatter_(
index.dim() - 1, arg_out, grad_out)
grad_src = grad_src.narrow(index.dim() - 1, 0,
src_size[index.dim() - 1] - 1)
return grad_src, None, None, None, None
class SegmentCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out, reduce):
assert reduce in ['add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
ctx.src_size = list(src.size())
out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
ctx.save_for_backward(indptr, arg_out)
return out if arg_out is None else (out, arg_out)
@staticmethod
def backward(ctx, grad_out, *args):
(indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'add':
grad_src = gather_cuda.gather_csr(grad_out, indptr,
grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
grad_src = gather_cuda.gather_csr(grad_out, indptr,
grad_out.new_empty(src_size))
indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
count = (indptr2 - indptr1).to(grad_src.dtype)
count = gather_cuda.gather_csr(
count, indptr, count.new_empty(src_size[:indptr.dim()]))
for _ in range(grad_out.dim() - indptr.dim()):
count = count.unsqueeze(-1)
grad_src.div_(count)
elif ctx.reduce == 'min' or ctx.reduce == 'max':
src_size[indptr.dim() - 1] += 1
grad_src = grad_out.new_zeros(src_size).scatter_(
indptr.dim() - 1, arg_out, grad_out)
grad_src = grad_src.narrow(indptr.dim() - 1, 0,
src_size[indptr.dim() - 1] - 1)
return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
return SegmentCOO.apply(src, index, out, dim_size, reduce)
def segment_csr(src, indptr, out=None, reduce='add'):
return SegmentCSR.apply(src, indptr, out, reduce)
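Finally, a usage sketch for the segment wrappers, again with the first test case from test_segment.py (CUDA build required). For 'min' and 'max', the wrapper returns the reduced values together with their argument indices:

import torch
from torch_scatter import segment_coo, segment_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')

out = segment_coo(src, index, reduce='add')    # tensor([ 3., 12.,  0.,  6.])
out = segment_csr(src, indptr, reduce='mean')  # tensor([1.5000, 4.0000, 0.0000, 6.0000])

out, arg_out = segment_csr(src, indptr, reduce='max')
# out:     tensor([2., 5., 0., 6.])
# arg_out: tensor([1, 4, 6, 5])  (empty segments point one past the last input row)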