Commit cd6d8d68 authored by rusty1s

all cuda kernels done

parent 7e82bc0e
+# flake8: noqa
 import time
 import os.path as osp
 import itertools
+import argparse
 import wget
 import torch
 from scipy.io import loadmat
+import torch_scatter
 from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
 from torch_scatter import segment_coo, segment_csr

+parser = argparse.ArgumentParser()
+parser.add_argument('--reduce', type=str, required=True)
+parser.add_argument('--device', type=str, default='cuda')
+args = parser.parse_args()
+args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+
 iters = 20
-device = 'cuda'
 sizes = [1, 16, 32, 64, 128, 256, 512]
 short_rows = [
@@ -40,13 +49,13 @@ def bold(text, flag=True):

 def correctness(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
     dim_size = rowptr.size(0) - 1

     for size in sizes:
         try:
-            x = torch.randn((row.size(0), size), device=device)
+            x = torch.randn((row.size(0), size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x

             out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
@@ -63,92 +72,71 @@ def correctness(dataset):
             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)

-            # out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
-            # out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
-            # print(out1[:5])
-            # print(out3[:5])
-            # nnz = (out1 != out3).nonzero().flatten()
-            # nnz1 = nnz[0].item()
-            # print(rowptr[nnz1], rowptr[nnz1 + 1])
-            # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
-            # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
-            # print(out1[nnz1])
-            # print(out3[nnz1])
-            # assert torch.allclose(out1, out3, atol=1e-4)
-            # assert torch.all(arg_out1 == arg_out3)
+            x = x.abs_().mul_(-1)
+
+            out1, arg_out1 = scatter_min(x, row, 0, torch.zeros_like(out1))
+            out2, arg_out2 = segment_coo(x, row, reduce='min')
+            out3, arg_out3 = segment_csr(x, rowptr, reduce='min')
+            assert torch.allclose(out1, out2, atol=1e-4)
+            assert torch.allclose(out1, out3, atol=1e-4)
+
+            x = x.abs_()
+
+            out1, arg_out1 = scatter_max(x, row, 0, torch.zeros_like(out1))
+            out2, arg_out2 = segment_coo(x, row, reduce='max')
+            out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
+            assert torch.allclose(out1, out2, atol=1e-4)
+            assert torch.allclose(out1, out3, atol=1e-4)
         except RuntimeError:
             torch.cuda.empty_cache()
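
The same kind of cross-check can be reproduced on a tiny hand-made example. The snippet below is an illustrative sketch with made-up values, not part of the patch; it assumes the in-development API shown above, where `segment_coo`/`segment_csr` may return an `(out, arg)` tuple for 'min'/'max', so it unpacks defensively.

```python
import torch
from torch_scatter import scatter_max, segment_coo, segment_csr

# Toy CSR/COO pair with no empty rows (values chosen arbitrarily).
src = torch.tensor([3., 1., 4., 1., 5., 9.])
row = torch.tensor([0, 0, 1, 1, 1, 2])       # sorted COO row index
rowptr = torch.tensor([0, 2, 5, 6])          # CSR row pointer for 3 rows

out1, arg1 = scatter_max(src, row, dim=0, dim_size=3)

out2 = segment_coo(src, row, reduce='max')
out2 = out2[0] if isinstance(out2, tuple) else out2

out3 = segment_csr(src, rowptr, reduce='max')
out3 = out3[0] if isinstance(out3, tuple) else out3

assert torch.allclose(out1, out2) and torch.allclose(out1, out3)
print(out1)  # tensor([3., 5., 9.])
```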
+
+def time_func(func, x):
+    try:
+        torch.cuda.synchronize()
+        t = time.perf_counter()
+        for _ in range(iters):
+            func(x)
+        torch.cuda.synchronize()
+        return time.perf_counter() - t
+    except RuntimeError:
+        torch.cuda.empty_cache()
+        return float('inf')
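
The new `time_func` helper brackets the loop with `torch.cuda.synchronize()` so that the wall-clock difference covers the asynchronous kernel launches. For comparison, a sketch of an equivalent helper based on CUDA events (hypothetical, not part of the patch):

```python
import torch

def time_func_events(func, x, iters=20):
    # Events are recorded on the current CUDA stream; elapsed_time() is in ms.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        func(x)
    end.record()
    torch.cuda.synchronize()                  # make sure `end` has been reached
    return start.elapsed_time(end) / 1000.0   # seconds for `iters` calls
```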
 @torch.no_grad()
 def timing(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
     row_perm = row[torch.randperm(row.size(0))]
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size

+    sca_row = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
+        x, row, dim=0, dim_size=dim_size)
+    sca_col = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
+        x, row_perm, dim=0, dim_size=dim_size)
+    seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce)
+    seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce)
+    dense1 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-2)
+    dense2 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-1)
+
     t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
     for size in sizes:
         try:
-            x = torch.randn((row.size(0), size), device=device)
+            x = torch.randn((row.size(0), size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x

-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = scatter_add(x, row, dim=0, dim_size=dim_size)
-                del out
-                torch.cuda.synchronize()
-                t1.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t1.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = scatter_add(x, row_perm, dim=0, dim_size=dim_size)
-                del out
-                torch.cuda.synchronize()
-                t2.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t2.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = segment_coo(x, row, dim_size=dim_size, reduce='any')
-                del out
-                torch.cuda.synchronize()
-                t3.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t3.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = segment_csr(x, rowptr, reduce='any')
-                del out
-                torch.cuda.synchronize()
-                t4.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t4.append(float('inf'))
+            t1 += [time_func(sca_row, x)]
+            t2 += [time_func(sca_col, x)]
+            t3 += [time_func(seg_coo, x)]
+            t4 += [time_func(seg_csr, x)]

             del x
@@ -159,35 +147,11 @@ def timing(dataset):

         try:
             x = torch.randn((dim_size, int(avg_row_len + 1), size),
-                            device=device)
-            x = x.squeeze(-1) if size == 1 else x
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.sum(dim=1)
-                del out
-                torch.cuda.synchronize()
-                t5.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t5.append(float('inf'))
+                            device=args.device)
+
+            t5 += [time_func(dense1, x)]

             x = x.view(dim_size, size, int(avg_row_len + 1))
-            x = x.squeeze(-2) if size == 1 else x
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.sum(dim=-1)
-                del out
-                torch.cuda.synchronize()
-                t6.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t6.append(float('inf'))
+            t6 += [time_func(dense2, x)]

             del x
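
The `dense1`/`dense2` baselines time a plain dense reduction over a padded tensor of roughly shape `(dim_size, avg_row_len + 1, size)`; the same storage is also viewed with the row dimension innermost, since reducing over the last (contiguous) dimension has a different memory-access pattern than reducing over a middle one. A small sketch of the idea with assumed toy sizes:

```python
import torch

dim_size, avg_row_len, size = 1000, 8, 64
x = torch.randn(dim_size, avg_row_len, size)

out_a = x.sum(dim=-2)                       # like dense1: reduce a middle dim
x = x.view(dim_size, size, avg_row_len)     # reinterpret, row dim innermost
out_b = x.sum(dim=-1)                       # like dense2: reduce the contiguous dim

print(out_a.shape, out_b.shape)             # both torch.Size([1000, 64])
```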
@@ -221,7 +185,7 @@ def timing(dataset):

 if __name__ == '__main__':
     for _ in range(10):  # Warmup.
-        torch.randn(100, 100, device=device).sum()
+        torch.randn(100, 100, device=args.device).sum()
     for dataset in itertools.chain(short_rows, long_rows):
         download(dataset)
         correctness(dataset)
...
@@ -34,12 +34,22 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == MIN) {
       return std::numeric_limits<scalar_t>::max();
     } else if (REDUCE == MAX) {
-      return std::numeric_limits<scalar_t>::min();
+      return std::numeric_limits<scalar_t>::lowest();
     } else {
       return (scalar_t)0;
     }
   }

+  static inline __host__ __device__ void update(scalar_t *val,
+                                                scalar_t new_val) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+    }
+  }
+
   static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                 int64_t *arg, int64_t new_arg) {
     if (REDUCE == ADD || REDUCE == MEAN) {
@@ -68,9 +78,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }

-  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
-                                             int64_t *arg_address,
-                                             int64_t arg) {
+  static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
     if (REDUCE == ADD) {
       atomAdd(address, val);
     } else if (REDUCE == MEAN) {
@@ -80,14 +88,6 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     } else if (REDUCE == MAX && val > *address) {
       atomMax(address, val);
     }
-
-    if (REDUCE == MIN || REDUCE == MAX) {
-      assert(false); // TODO
-      __syncthreads();
-      if (*address == val) {
-        *arg_address = arg;
-      }
-    }
   }
 };
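
A note on the `lowest()` fix in `Reducer::init()`: for floating-point types, `std::numeric_limits<T>::min()` is the smallest *positive* normalized value, not the most negative one, so it is the wrong identity for a MAX reduction; `lowest()` returns the most negative finite value. The same distinction, illustrated from Python (a sketch, not part of the patch):

```python
import numpy as np

info = np.finfo(np.float32)
print(info.tiny)  # ~1.18e-38: smallest positive normal (what min() gives in C++)
print(info.min)   # ~-3.40e+38: most negative finite value (what lowest() gives)
print(info.max)   # ~ 3.40e+38: the matching identity for a MIN reduction
```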
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
+    scalar_t val = Reducer<scalar_t, REDUCE>::init();
     int64_t arg, arg_tmp;

     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
@@ -123,17 +123,11 @@ segment_csr_kernel(const scalar_t *src_data,
 #pragma unroll
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
-      if (REDUCE == MIN || REDUCE == MAX) {
-        tmp = __shfl_down_sync(FULL_MASK, val, i);
+      if (REDUCE == MIN || REDUCE == MAX)
         arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
-        // Only update valid entries.
-        if (lane_idx < i && row_start + lane_idx + i < row_end)
-          Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
-      } else {
-        Reducer<scalar_t, REDUCE>::update(
-            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
-      }
+      Reducer<scalar_t, REDUCE>::update(
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }

     if (lane_idx == 0) {
       Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
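
In `segment_csr_kernel`, each group of `TB` lanes cooperatively reduces the values of one CSR row: the stride is halved every step and each surviving lane folds in the value shuffled down from `stride` lanes away, so lane 0 holds the row's result after log2(TB) steps. A host-side simulation of that pattern (illustrative only, assuming a power-of-two `TB`):

```python
def warp_tree_reduce(vals, op):
    # `vals` holds one value per lane; mimics __shfl_down_sync with stride i.
    vals = list(vals)
    i = len(vals) // 2
    while i > 0:
        for lane in range(i):
            vals[lane] = op(vals[lane], vals[lane + i])
        i //= 2
    return vals[0]  # lane 0 holds the reduced value

print(warp_tree_reduce([3, 1, 4, 1, 5, 9, 2, 6], max))  # 9
```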
@@ -256,7 +250,7 @@ template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
 __global__ void
 segment_coo_kernel(const scalar_t *src_data,
                    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {
+                   scalar_t *out_data, size_t E, size_t N) {

   // Each thread processes exactly one entry. Within a warp, we perform a
   // parallel reduction across equal indices, and write the intermediate
@@ -269,32 +263,44 @@ segment_coo_kernel(const scalar_t *src_data,
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
+    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

     scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
-    int64_t arg, arg_tmp;
-
-    if (REDUCE == MIN || REDUCE == MAX) {
-      arg = row_idx % index_info.sizes[index_info.dims - 1];
-    }

 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
-      if (REDUCE == MIN || REDUCE == MAX) {
-        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
-      }
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       assert(idx >= next_idx);
       if (lane_idx >= i && idx == next_idx)
-        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
+        Reducer<scalar_t, REDUCE>::update(&val, tmp);
     }

     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || idx != next_idx) {
-      Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
-                                              arg_out_data + idx, arg);
+      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
     }
   }
 }
+
+template <typename scalar_t>
+__global__ void segment_coo_arg_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int idx = index_info.data[offset];
+    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[row_idx] == val)
+      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
+  }
+}
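
The new `segment_coo_arg_kernel` shows the strategy this commit uses for argmin/argmax: the value kernel no longer tracks argument indices at all; a second pass reads the already-reduced output and lets every source position whose value matches it record its row-local index. A CPU sketch of the same two-pass idea on hypothetical toy data:

```python
src = [3., 1., 5., 1., 5.]
index = [0, 0, 1, 1, 1]
dim_size = 2

# Pass 1: reduce the values (here: max per output slot).
out = [float('-inf')] * dim_size
for i, v in zip(index, src):
    out[i] = max(out[i], v)

# Pass 2: every position whose value equals the reduced one writes its index.
# Ties go to whoever writes last, matching the unordered writes on the GPU.
arg = [-1] * dim_size
for pos, (i, v) in enumerate(zip(index, src)):
    if v == out[i]:
        arg[i] = pos

print(out)  # [3.0, 5.0]
print(arg)  # [0, 4]   (position 2 also holds 5.0, but 4 wrote last)
```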
@@ -302,7 +308,7 @@ template <typename scalar_t, ReductionType REDUCE, int TB>
 __global__ void segment_coo_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
+    scalar_t *out_data, size_t E, size_t K, size_t N) {

   // Each thread processes a single column and `TB` index entries. Coalesced
   // read and write is performed in column-major order. The intermediate
@@ -314,6 +320,7 @@ __global__ void segment_coo_broadcast_kernel(
   if (row_start < E && col_idx < K) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_start, index_info);
+    int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
     int idx1 = __ldg(index_info.data + offset);

     scalar_t val = src_data[K * row_start + col_idx];
@@ -327,15 +334,42 @@ __global__ void segment_coo_broadcast_kernel(
                        i * index_info.strides[index_info.dims - 1]);
       assert(idx1 <= idx2);
       if (idx1 == idx2) {
-        val += src_data[K * (row_start + i) + col_idx];
+        Reducer<scalar_t, REDUCE>::update(
+            &val, src_data[K * (row_start + i) + col_idx]);
       } else {
-        atomAdd(out_data + K * idx1 + col_idx, val);
+        Reducer<scalar_t, REDUCE>::atomic_write(
+            out_data + (out_idx + idx1) * K + col_idx, val);
         val = src_data[K * (row_start + i) + col_idx];
       }

       idx1 = idx2;
     }

-    atomAdd(out_data + K * idx1 + col_idx, val);
+    Reducer<scalar_t, REDUCE>::atomic_write(
+        out_data + (out_idx + idx1) * K + col_idx, val);
   }
 }
+
+template <typename scalar_t>
+__global__ void segment_coo_arg_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int col_idx = thread_idx % K;
+
+  if (row_idx < E && col_idx < K) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int idx = __ldg(index_info.data + offset);
+    int out_idx =
+        ((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx) * K +
+        col_idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[thread_idx] == val)
+      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
+  }
+}
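
The kernels now take the additional output size `N` so that a batched index of shape `(B, E)` can reduce into an output of shape `(B, N, K)`: for flattened source row `row_idx`, the batch is `row_idx / E` and the flat offset is `((row_idx / E) * N + idx) * K + col_idx`, exactly as computed in `segment_coo_arg_broadcast_kernel` above. A sketch of that offset arithmetic with assumed toy sizes:

```python
B, E, N, K = 2, 5, 3, 4
index_flat = [0, 0, 1, 2, 2,   # batch 0
              1, 1, 1, 2, 2]   # batch 1

for row_idx, idx in enumerate(index_flat):
    out_base = ((row_idx // E) * N + idx) * K   # + col_idx for each column
    print(row_idx, '->', out_base)
```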
@@ -371,6 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,

   auto E = index.numel();
   auto K = src.numel() / E;
+  auto N = out.size(reduce_dim);
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
@@ -383,25 +418,37 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
       if (K == 1) {
         segment_coo_kernel<scalar_t, REDUCE, true>
             <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
-                                                   out_data, arg_out_data, E);
+                                                   out_data, E, N);
       } else if (avg_len <= 8) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
             <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
-               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+               0, stream>>>(src_data, index_info, out_data, E, K, N);
       } else if (avg_len <= 16) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
             <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
-               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+               0, stream>>>(src_data, index_info, out_data, E, K, N);
       } else if (avg_len <= 32) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
             <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
       } else {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
             <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
-               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
+                                         N);
       }
+
+      if (REDUCE == MIN || REDUCE == MAX) {
+        if (K == 1) {
+          segment_coo_arg_kernel<scalar_t>
+              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, N);
+        } else {
+          segment_coo_arg_broadcast_kernel<scalar_t>
+              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, K, N);
+        }
+      }
     });
   });
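
The dispatch above picks the per-thread tile `TB` from the average segment length and launches blocks of `dim3(32, 8)` threads: 32 columns per block, 8 rows of threads, each covering `TB` consecutive index entries of one column. A small sketch of the resulting grid sizes (helper name assumed, not part of the patch):

```python
import math

def launch_grid(E, K, TB):
    grid_x = math.ceil(E / (8 * TB))   # 8 thread rows, TB index entries each
    grid_y = math.ceil(K / 32)         # 32 columns per block
    return grid_x, grid_y

print(launch_grid(E=10_000, K=64, TB=8))   # (157, 2)
```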
@@ -415,12 +462,17 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
       auto count_data = count.DATA_PTR<scalar_t>();
       segment_coo_kernel<scalar_t, ADD, false>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
-                                                 count_data, nullptr, E);
+                                                 count_data, E, N);
     });
     count.clamp_(1);
-    out.div_(count);
     arg_out = count;
+
+    for (int i = reduce_dim + 1; i < out.dim(); i++) {
+      count = count.unsqueeze(-1);
+    }
+    out.div_(count);
   }

   return std::make_tuple(out, arg_out);
...
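
The reshuffled 'mean' epilogue above computes the mean as an ADD reduction divided by a per-slot element count, clamps the count to at least one so that empty slots do not divide by zero, and unsqueezes it over the trailing feature dimensions before dividing. The same logic in a few lines of plain PyTorch (a sketch, not the library code):

```python
import torch

src = torch.tensor([[1., 1.], [2., 2.], [4., 4.]])
index = torch.tensor([0, 0, 2])

out = torch.zeros(3, 2).index_add_(0, index, src)            # ADD reduction
count = torch.zeros(3).index_add_(0, index, torch.ones(3))   # per-slot counts
count.clamp_(min=1)                                          # guard empty slot 1
out = out / count.unsqueeze(-1)                              # broadcast over features

print(out)  # tensor([[1.5, 1.5], [0.0, 0.0], [4.0, 4.0]])
```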
@@ -3,7 +3,7 @@ from itertools import product
 import pytest
 import torch
 from torch_scatter import segment_coo, segment_csr
-from torch_scatter import scatter_add, scatter_mean, scatter_min  # noqa
+from torch_scatter import scatter_max

 from .utils import tensor
@@ -18,24 +18,39 @@ def test_forward(dtype, device):
                     device)
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
-    # src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
     src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

-    # out = scatter_min(src, index, dim=0)[0]
-    # print('SCA', out)
+    out, arg = scatter_max(src, index, dim=0)
+    print('SCA')
+    print(out)
+    print(arg)

     # grad_out = torch.randn_like(out)
     # print(grad_out)
     # out.backward(grad_out)
     # print(src.grad)
-    src.grad = None
+    # src.grad = None

-    out = segment_csr(src, indptr, reduce='mean')
-    print('CSR', out)
+    out, arg = segment_coo(src, index, reduce='max')
+    print('COO')
+    print(out)
+    print(arg)
+
+    out, arg = segment_csr(src, indptr, reduce='max')
+    print('CSR')
+    print(out)
+    print(arg)

     # out.backward(grad_out)
     # print(src.grad)

     # out = out[0] if isinstance(out, tuple) else out
     # out.backward(torch.randn_like(out))
-    out = segment_coo(src, index, reduce='mean')
-    print('COO', out)
+    # out = segment_coo(src, index, reduce='max')[0]
+    # print('COO', out)