Commit cd6d8d68 authored by rusty1s

all cuda kernels done

parent 7e82bc0e
# flake8: noqa
import time
import os.path as osp
import itertools
import argparse
import wget
import torch
from scipy.io import loadmat
import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
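# torch's dense reductions are named 'sum'/'mean'/'min'/'max', so map 'add' to 'sum'.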
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
iters = 20
device = 'cuda'
sizes = [1, 16, 32, 64, 128, 256, 512]
short_rows = [
......@@ -40,13 +49,13 @@ def bold(text, flag=True):
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=device)
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
......@@ -63,92 +72,71 @@ def correctness(dataset):
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
# out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
# out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
x = x.abs_().mul_(-1)
# print(out1[:5])
# print(out3[:5])
out1, arg_out1 = scatter_min(x, row, 0, torch.zeros_like(out1))
out2, arg_out2 = segment_coo(x, row, reduce='min')
out3, arg_out3 = segment_csr(x, rowptr, reduce='min')
# nnz = (out1 != out3).nonzero().flatten()
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
# nnz1 = nnz[0].item()
# print(rowptr[nnz1], rowptr[nnz1 + 1])
x = x.abs_()
# print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
# print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
out1, arg_out1 = scatter_max(x, row, 0, torch.zeros_like(out1))
out2, arg_out2 = segment_coo(x, row, reduce='max')
out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
# print(out1[nnz1])
# print(out3[nnz1])
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
# assert torch.allclose(out1, out3, atol=1e-4)
# assert torch.all(arg_out1 == arg_out3)
except RuntimeError:
torch.cuda.empty_cache()
def time_func(func, x):
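# Run `func(x)` `iters` times between CUDA synchronizations and return the
# elapsed wall-clock time; report inf if the kernel runs out of memory.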
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
func(x)
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError:
torch.cuda.empty_cache()
return float('inf')
@torch.no_grad()
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row_perm = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
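# Benchmarked variants: scatter over the sorted row index, scatter over a random
# permutation of it (poor write locality), and the fused COO/CSR segment reductions.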
sca_row = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
x, row, dim=0, dim_size=dim_size)
sca_col = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
x, row_perm, dim=0, dim_size=dim_size)
seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce)
seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce)
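# Dense baselines: reduce a contiguous (dim_size, avg_row_len + 1, size) tensor
# over the row dimension (dense1) and, in the transposed layout, over the last
# dimension (dense2).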
dense1 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-2)
dense2 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=device)
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = scatter_add(x, row, dim=0, dim_size=dim_size)
del out
torch.cuda.synchronize()
t1.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t1.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = scatter_add(x, row_perm, dim=0, dim_size=dim_size)
del out
torch.cuda.synchronize()
t2.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t2.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = segment_coo(x, row, dim_size=dim_size, reduce='any')
del out
torch.cuda.synchronize()
t3.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t3.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = segment_csr(x, rowptr, reduce='any')
del out
torch.cuda.synchronize()
t4.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t4.append(float('inf'))
t1 += [time_func(sca_row, x)]
t2 += [time_func(sca_col, x)]
t3 += [time_func(seg_coo, x)]
t4 += [time_func(seg_csr, x)]
del x
......@@ -159,35 +147,11 @@ def timing(dataset):
try:
x = torch.randn((dim_size, int(avg_row_len + 1), size),
device=device)
x = x.squeeze(-1) if size == 1 else x
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.sum(dim=1)
del out
torch.cuda.synchronize()
t5.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t5.append(float('inf'))
device=args.device)
t5 += [time_func(dense1, x)]
x = x.view(dim_size, size, int(avg_row_len + 1))
x = x.squeeze(-2) if size == 1 else x
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.sum(dim=-1)
del out
torch.cuda.synchronize()
t6.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t6.append(float('inf'))
t6 += [time_func(dense2, x)]
del x
......@@ -221,7 +185,7 @@ def timing(dataset):
if __name__ == '__main__':
for _ in range(10): # Warmup.
torch.randn(100, 100, device=device).sum()
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
......@@ -34,12 +34,22 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
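// numeric_limits::min() is the smallest *positive* value for floating-point
// types, so lowest() is the correct identity element for a MAX reduction.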
return std::numeric_limits<scalar_t>::min();
return std::numeric_limits<scalar_t>::lowest();
} else {
return (scalar_t)0;
}
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
......@@ -68,9 +78,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg) {
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == ADD) {
atomAdd(address, val);
} else if (REDUCE == MEAN) {
......@@ -80,14 +88,6 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} else if (REDUCE == MAX && val > *address) {
atomMax(address, val);
}
if (REDUCE == MIN || REDUCE == MAX) {
assert(false); // TODO
__syncthreads();
if (*address == val) {
*arg_address = arg;
}
}
}
};
......@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
......@@ -123,16 +123,10 @@ segment_csr_kernel(const scalar_t *src_data,
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
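// E.g. with TB = 32: lanes 0-15 fold in lanes 16-31, then lanes 0-7 fold in
// lanes 8-15, and so on, until lane 0 holds the reduced value of the row.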
if (REDUCE == MIN || REDUCE == MAX) {
tmp = __shfl_down_sync(FULL_MASK, val, i);
if (REDUCE == MIN || REDUCE == MAX)
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
// Only update valid entries.
if (lane_idx < i && row_start + lane_idx + i < row_end)
Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
} else {
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
Reducer<scalar_t, REDUCE>::update(
&val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
}
if (lane_idx == 0) {
......@@ -256,7 +250,7 @@ template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E) {
scalar_t *out_data, size_t E, size_t N) {
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
......@@ -269,40 +263,52 @@ segment_coo_kernel(const scalar_t *src_data,
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
int64_t arg, arg_tmp;
if (REDUCE == MIN || REDUCE == MAX) {
arg = row_idx % index_info.sizes[index_info.dims - 1];
}
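// Segmented warp scan, e.g. with index = [0, 0, 1, 1] and src = [1, 2, 3, 4]:
// an ADD reduction leaves 3 in lane 1 and 7 in lane 3, i.e. the last lane of
// each run of equal indices ends up holding that segment's reduced value.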
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
if (REDUCE == MIN || REDUCE == MAX) {
arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
}
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
assert(idx >= next_idx);
if (lane_idx >= i && idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
Reducer<scalar_t, REDUCE>::update(&val, tmp);
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || idx != next_idx) {
Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
arg_out_data + idx, arg);
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
}
}
}
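// Second pass for MIN/MAX reductions: the value kernel above has already
// written the reduced result into `out_data`, so each input element re-reads
// that result and, if its own value matches, records its position in
// `arg_out_data` (ties are resolved by whichever matching thread writes last).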
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = index_info.data[offset];
int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[row_idx] == val)
arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
}
}
template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
scalar_t *out_data, size_t E, size_t K, size_t N) {
// Each thread processes a single column and `TB` index entries. Coalesced
// read and write is performed in column-major order. The intermediate
......@@ -314,6 +320,7 @@ __global__ void segment_coo_broadcast_kernel(
if (row_start < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_start, index_info);
int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
int idx1 = __ldg(index_info.data + offset);
scalar_t val = src_data[K * row_start + col_idx];
......@@ -327,15 +334,42 @@ __global__ void segment_coo_broadcast_kernel(
i * index_info.strides[index_info.dims - 1]);
assert(idx1 <= idx2);
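// Consecutive rows handled by this thread that share the same index are
// combined in registers; an atomic write is only issued when the index
// changes, plus once more for the final run after the loop.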
if (idx1 == idx2) {
val += src_data[K * (row_start + i) + col_idx];
Reducer<scalar_t, REDUCE>::update(
&val, src_data[K * (row_start + i) + col_idx]);
} else {
atomAdd(out_data + K * idx1 + col_idx, val);
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (out_idx + idx1) * K + col_idx, val);
val = src_data[K * (row_start + i) + col_idx];
}
idx1 = idx2;
}
atomAdd(out_data + K * idx1 + col_idx, val);
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (out_idx + idx1) * K + col_idx, val);
}
}
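// Broadcast variant of the argument pass: one thread per (row, column) element
// compares its source value with the reduced result at the corresponding
// output position and stores the matching row offset.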
template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
if (row_idx < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = __ldg(index_info.data + offset);
int out_idx =
((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx) * K +
col_idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[thread_idx] == val)
arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
}
}
......@@ -371,6 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
auto E = index.numel();
auto K = src.numel() / E;
auto N = out.size(reduce_dim);
auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
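// The average segment length selects how many index entries each thread of the
// broadcast kernel processes (TB = 4, 8, 16 or 32 below).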
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
......@@ -383,25 +418,37 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
if (K == 1) {
segment_coo_kernel<scalar_t, REDUCE, true>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
out_data, arg_out_data, E);
out_data, E, N);
} else if (avg_len <= 8) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
<<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
0, stream>>>(src_data, index_info, out_data, E, K, N);
} else if (avg_len <= 16) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
<<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
0, stream>>>(src_data, index_info, out_data, E, K, N);
} else if (avg_len <= 32) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
<<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
arg_out_data, E, K);
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
<<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
arg_out_data, E, K);
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
}
if (REDUCE == MIN || REDUCE == MAX) {
if (K == 1) {
segment_coo_arg_kernel<scalar_t>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, N);
} else {
segment_coo_arg_broadcast_kernel<scalar_t>
<<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N);
}
}
});
});
......@@ -415,12 +462,17 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
auto count_data = count.DATA_PTR<scalar_t>();
segment_coo_kernel<scalar_t, ADD, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, nullptr, E);
count_data, E, N);
});
count.clamp_(1);
out.div_(count);
arg_out = count;
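// Add singleton dimensions so that `count` broadcasts over the trailing
// feature dimensions of `out` when dividing for the mean.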
for (int i = reduce_dim + 1; i < out.dim(); i++) {
count = count.unsqueeze(-1);
}
out.div_(count);
}
return std::make_tuple(out, arg_out);
......@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter_add, scatter_mean, scatter_min # noqa
from torch_scatter import scatter_max
from .utils import tensor
......@@ -18,24 +18,39 @@ def test_forward(dtype, device):
device)
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
# src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
src.requires_grad_()
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
# out = scatter_min(src, index, dim=0)[0]
out, arg = scatter_max(src, index, dim=0)
print('SCA')
print(out)
print(arg)
# print('SCA', out)
# grad_out = torch.randn_like(out)
# print(grad_out)
# out.backward(grad_out)
# print(src.grad)
src.grad = None
out = segment_csr(src, indptr, reduce='mean')
print('CSR', out)
# src.grad = None
out, arg = segment_coo(src, index, reduce='max')
print('COO')
print(out)
print(arg)
out, arg = segment_csr(src, indptr, reduce='max')
print('CSR')
print(out)
print(arg)
# out.backward(grad_out)
# print(src.grad)
# out = out[0] if isinstance(out, tuple) else out
# out.backward(torch.randn_like(out))
out = segment_coo(src, index, reduce='mean')
print('COO', out)
# out = segment_coo(src, index, reduce='max')[0]
# print('COO', out)