Commit 2a7622b6 authored by rusty1s

max fix + compute capability 3.5

parent 52f2ad25
@@ -6,7 +6,8 @@ import wget
import torch
from scipy.io import loadmat
-from torch_scatter import scatter_add, segment_csr, segment_coo
+from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
+from torch_scatter import segment_coo, segment_csr
iters = 20
device = 'cuda'
@@ -54,6 +55,33 @@ def correctness(dataset):
        assert torch.allclose(out1, out2, atol=1e-4)
        assert torch.allclose(out1, out3, atol=1e-4)

+        out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
+        out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
+        out3 = segment_csr(x, rowptr, reduce='mean')
+        assert torch.allclose(out1, out2, atol=1e-4)
+        assert torch.allclose(out1, out3, atol=1e-4)
+
+        out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
+        out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
+        # print(out1[:5])
+        # print(out3[:5])
+        nnz = (out1 != out3).nonzero().flatten()
+        nnz1 = nnz[0].item()
+        print(rowptr[nnz1], rowptr[nnz1 + 1])
+        print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
+        print(out1[nnz1])
+        print(out3[nnz1])
+        assert torch.allclose(out1, out3, atol=1e-4)
+        assert torch.all(arg_out1 == arg_out3)
    except RuntimeError:
        torch.cuda.empty_cache()
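A minimal standalone sketch of what the new checks assert (not part of the commit; tiny made-up tensors stand in for the benchmark matrices, and min/max are assumed to return an (out, arg_out) tuple as of this commit):

    import torch
    from torch_scatter import scatter_max, scatter_mean, segment_coo, segment_csr

    x = torch.tensor([1., 4., 2., 3., 9.])
    row = torch.tensor([0, 0, 1, 1, 1])  # sorted COO row indices
    rowptr = torch.tensor([0, 2, 5])     # CSR pointers over the same two rows

    # All three 'mean' paths must agree: [2.5, 4.6667].
    out1 = scatter_mean(x, row, dim=0, dim_size=2)
    out2 = segment_coo(x, row, dim_size=2, reduce='mean')
    out3 = segment_csr(x, rowptr, reduce='mean')
    assert torch.allclose(out1, out2) and torch.allclose(out1, out3)

    # 'max' additionally returns argmax positions: values [4., 9.], args [1, 4].
    out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=2)
    out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
    assert torch.allclose(out1, out3) and torch.all(arg_out1 == arg_out3)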
@@ -197,4 +225,4 @@ if __name__ == '__main__':
for dataset in itertools.chain(short_rows, long_rows):
    download(dataset)
    correctness(dataset)
-    timing(dataset)
+    # timing(dataset)
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
  int row_end = __ldg(indptr_info.data + offset +
                      indptr_info.strides[indptr_info.dims - 1]);

-  scalar_t val = Reducer<scalar_t, REDUCE>::init();
+  scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
  int64_t arg, arg_tmp;

  offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
@@ -124,11 +124,15 @@ segment_csr_kernel(const scalar_t *src_data,
  for (int i = TB / 2; i > 0; i /= 2) {
    // Parallel reduction inside a single warp.
+    if (REDUCE == MIN || REDUCE == MAX) {
+      tmp = __shfl_down_sync(FULL_MASK, val, i);
+      arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
+    }
    if (row_start + lane_idx + i < row_end) {
+      if (REDUCE == MIN || REDUCE == MAX) {
+        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
+      } else {
        Reducer<scalar_t, REDUCE>::update(
            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
+      }
    }
  }
if (lane_idx == 0) {
Reducer<scalar_t, REDUCE>::write(out_data + row_idx, val,
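Background on this hunk: `__shfl_down_sync(FULL_MASK, ...)` is only well-defined when every lane named in the mask executes the intrinsic. The old code issued the shuffles inside the `row_start + lane_idx + i < row_end` guard, so lanes past the end of a short row never published their `val`/`arg`, and `min`/`max` reductions could fold in stale registers; this is the "max fix" of the commit title. The shuffles for `min`/`max` now run unconditionally on all lanes, and only the guarded `update()` stays behind the bounds check.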
@@ -246,7 +250,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
return std::make_tuple(out, arg_out);
}
-template <typename scalar_t, ReductionType REDUCE>
+template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
@@ -264,8 +268,12 @@ segment_coo_kernel(const scalar_t *src_data,
                                         row_idx, index_info);
  int idx = index_info.data[offset], next_idx;

-  scalar_t val = src_data[row_idx], tmp;
-  int64_t arg = row_idx % index_info.sizes[index_info.dims - 1], arg_tmp;
+  scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
+  int64_t arg, arg_tmp;
+  if (REDUCE == MIN || REDUCE == MAX) {
+    arg = row_idx % index_info.sizes[index_info.dims - 1];
+  }
#pragma unroll
for (int i = 1; i < 32; i *= 2) {
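The new `HAS_VAL` template parameter lets the same kernel run without a source array: with `HAS_VAL == false`, every thread contributes a constant 1, so an `ADD` reduction over the index vector produces per-output element counts (used by the `mean` path below). `arg` is now initialized only for `min`/`max`, the only reductions that consume it.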
@@ -298,7 +306,7 @@ __global__ void segment_coo_broadcast_kernel(
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
-  int row_start = blockIdx.x * (blockDim.y + threadIdx.y) * TB;
+  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
if (row_start < E && col_idx < K) {
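The `row_start` change fixes misplaced parentheses: `blockIdx.x * (blockDim.y + threadIdx.y) * TB` scales the block index by `blockDim.y + threadIdx.y`, mapping threads to wrong and overlapping rows, while the corrected `(blockIdx.x * blockDim.y + threadIdx.y) * TB` is the flattened y-thread index times its `TB`-row tile.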
@@ -371,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (K == 1) {
-      segment_coo_kernel<scalar_t, REDUCE>
+      segment_coo_kernel<scalar_t, REDUCE, true>
          <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                 out_data, arg_out_data, E);
} else if (avg_len <= 8) {
@@ -397,12 +405,19 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
});
if (reduce == "mean") {
auto count = at::empty_like(index, out.options());
auto sizes = index.sizes().vec();
sizes[reduce_dim] = out.size(reduce_dim);
auto count = at::zeros(sizes, out.options());
AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
auto count_data = count.DATA_PTR<scalar_t>();
AT_ASSERTM(false); // TODO
segment_coo_kernel<scalar_t, ADD, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, nullptr, E);
});
out = out / count;
count.clamp_(1);
out.div_(count);
arg_out = count;
}
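'mean' is now computed as an `ADD` reduction divided by a per-slot element count, produced by relaunching the kernel with `HAS_VAL == false` (an implicit stream of ones in place of `src`); clamping the count to 1 keeps empty slots at 0 rather than 0/0 = NaN. A plain-PyTorch sketch of the same strategy (not part of the commit; values made up):

    import torch

    src = torch.tensor([1., 4., 2., 3., 9.])
    index = torch.tensor([0, 0, 3, 3, 3])
    dim_size = 5  # slots 1, 2 and 4 receive no values

    out = torch.zeros(dim_size).scatter_add_(0, index, src)
    count = torch.zeros(dim_size).scatter_add_(0, index, torch.ones_like(src))
    count.clamp_(min=1)  # empty slots divide by 1 instead of 0
    out.div_(count)      # tensor([2.5000, 0., 0., 4.6667, 0.])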
+import platform
import os.path as osp
from glob import glob
from setuptools import setup, find_packages
@@ -10,27 +11,33 @@ USE_GPU = True
if '--cpu' in argv:
USE_GPU = False
-extra_compile_args = []
+cxx_extra_compile_args = []
+nvcc_extra_compile_args = ['-arch=sm_35']
+if platform.system() != 'Windows':
+    cxx_extra_compile_args += ['-Wno-unused-variable']
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    extra_compile_args += ['-DVERSION_GE_1_3']
+    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
+    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
ext_modules = []
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
ext_modules += [
    CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=extra_compile_args) for ext in exts
+                 extra_compile_args=cxx_extra_compile_args) for ext in exts
]
-# ['-Wno-unused-variable'] if platform.system() != 'Windows' else []
if CUDA_HOME is not None and USE_GPU:
exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
ext_modules += [
-        CUDAExtension(f'torch_scatter.{ext}_cuda',
-                      [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'],
-                      extra_compile_args=extra_compile_args) for ext in exts
+        CUDAExtension(
+            f'torch_scatter.{ext}_cuda',
+            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
+                'cxx': cxx_extra_compile_args,
+                'nvcc': nvcc_extra_compile_args,
+            }) for ext in exts
]
__version__ = '1.5.0'
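Splitting the flags is what the dict form of `extra_compile_args` above enables: PyTorch's `BuildExtension` passes the `'cxx'` list to the host compiler and the `'nvcc'` list to the CUDA compiler, so `-Wno-unused-variable` no longer reaches nvcc and `-arch=sm_35` no longer reaches the host compiler. The `sm_35` floor is presumably needed because the kernels use `__ldg`, which requires compute capability 3.5.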
@@ -22,19 +22,20 @@ def test_forward(dtype, device):
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

-    out = scatter_min(src, index, dim=0)[0]
-    grad_out = torch.randn_like(out)
-    print(grad_out)
-    out.backward(grad_out)
-    print(src.grad)
+    # out = scatter_min(src, index, dim=0)[0]
+    # grad_out = torch.randn_like(out)
+    # print(grad_out)
+    # out.backward(grad_out)
+    # print(src.grad)

    src.grad = None

-    out = segment_csr(src, indptr, reduce='min')[0]
-    out.backward(grad_out)
-    print(src.grad)
+    out = segment_csr(src, indptr, reduce='mean')
+    print('CSR', out)
+    # out.backward(grad_out)
+    # print(src.grad)
+    # out = out[0] if isinstance(out, tuple) else out
+    # out.backward(torch.randn_like(out))

-    out = segment_coo(src, index, reduce='any')
+    out = segment_coo(src, index, reduce='mean')
+    print('COO', out)
import torch
-from torch_scatter.utils import min_value, max_value
+from torch_scatter.helpers import min_value, max_value

if torch.cuda.is_available():
    from torch_scatter import segment_cuda, gather_cuda
@@ -63,12 +63,16 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
        fill_value = min_value(src.dtype)
        out = src.new_full(size, fill_value)

    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)

    if fill_value != 0:
        out.masked_fill_(out == fill_value, 0)

-    return out if arg_out is None else (out, arg_out)
+    if reduce == 'min' or reduce == 'max':
+        return out, arg_out
+    else:
+        return out
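With this change the wrapper's return type is keyed on the reduction instead of on whether an `arg_out` happens to exist: `min`/`max` always return a `(values, arg)` tuple, everything else a single tensor. A usage sketch (not part of the commit; values made up):

    import torch
    from torch_scatter import segment_coo

    src = torch.tensor([1., 4., 2., 3., 9.])
    index = torch.tensor([0, 0, 1, 1, 1])

    out = segment_coo(src, index, reduce='add')           # tensor([ 5., 14.])
    out, arg_out = segment_coo(src, index, reduce='max')  # tensor([4., 9.]), tensor([1, 4])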
def segment_csr(src, indptr, out=None, reduce='add'):