"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "eb74fa34b41cd2fa615e8a0f7b29616c7e1fdb0f"
Commit 2a7622b6 authored by rusty1s

max fix + compute capability 3.5

parent 52f2ad25
@@ -6,7 +6,8 @@ import wget
 import torch
 from scipy.io import loadmat
-from torch_scatter import scatter_add, segment_csr, segment_coo
+from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
+from torch_scatter import segment_coo, segment_csr

 iters = 20
 device = 'cuda'
@@ -54,6 +55,33 @@ def correctness(dataset):
             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)
+
+            out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
+            out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
+            out3 = segment_csr(x, rowptr, reduce='mean')
+            assert torch.allclose(out1, out2, atol=1e-4)
+            assert torch.allclose(out1, out3, atol=1e-4)
+
+            out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
+            out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
+            # print(out1[:5])
+            # print(out3[:5])
+            nnz = (out1 != out3).nonzero().flatten()
+            nnz1 = nnz[0].item()
+            print(rowptr[nnz1], rowptr[nnz1 + 1])
+            print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
+            print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
+            print(out1[nnz1])
+            print(out3[nnz1])
+
+            assert torch.allclose(out1, out3, atol=1e-4)
+            assert torch.all(arg_out1 == arg_out3)
         except RuntimeError:
             torch.cuda.empty_cache()
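
The new checks pin scatter_*, segment_coo, and segment_csr to the same results across all three code paths. Stand-alone, with toy tensors in place of the benchmark's downloaded matrices (a sketch, not part of the commit), the mean check looks like:

    import torch
    from torch_scatter import scatter_mean, segment_coo, segment_csr

    x = torch.randn(6, device='cuda')
    row = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')    # sorted COO row indices
    rowptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')    # equivalent CSR pointers

    out1 = scatter_mean(x, row, dim=0, dim_size=4)
    out2 = segment_coo(x, row, dim_size=4, reduce='mean')
    out3 = segment_csr(x, rowptr, reduce='mean')
    assert torch.allclose(out1, out2, atol=1e-4)
    assert torch.allclose(out1, out3, atol=1e-4)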
@@ -197,4 +225,4 @@ if __name__ == '__main__':
     for dataset in itertools.chain(short_rows, long_rows):
         download(dataset)
         correctness(dataset)
-        timing(dataset)
+        # timing(dataset)
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
   int row_end = __ldg(indptr_info.data + offset +
                       indptr_info.strides[indptr_info.dims - 1]);

-  scalar_t val = Reducer<scalar_t, REDUCE>::init();
+  scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
   int64_t arg, arg_tmp;

   offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
@@ -124,10 +124,14 @@ segment_csr_kernel(const scalar_t *src_data,
   for (int i = TB / 2; i > 0; i /= 2) {
     // Parallel reduction inside a single warp.
     if (REDUCE == MIN || REDUCE == MAX) {
+      tmp = __shfl_down_sync(FULL_MASK, val, i);
       arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
+      if (row_start + lane_idx + i < row_end)
+        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
+    } else {
+      Reducer<scalar_t, REDUCE>::update(
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }
-    Reducer<scalar_t, REDUCE>::update(
-        &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
   }

   if (lane_idx == 0) {
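
The added bounds check is the "max fix" from the commit title: lanes whose offset falls past row_end never loaded an element, so their val register still holds the reduction identity and their arg may be uninitialized; folding them in can corrupt the argmin/argmax. A rough pure-Python model of the guarded tree reduction (8 lanes instead of a 32-lane warp; all names here are mine, not the kernel's):

    def warp_max(vals, args, valid):
        # valid[l] models `row_start + l + i < row_end` for the source lane l.
        n, i = len(vals), len(vals) // 2
        while i > 0:
            snap_vals, snap_args = vals[:], args[:]       # lanes shuffle in lock-step
            for lane in range(n - i):
                tmp = snap_vals[lane + i]                 # __shfl_down_sync(FULL_MASK, val, i)
                arg_tmp = snap_args[lane + i]             # __shfl_down_sync(FULL_MASK, arg, i)
                if valid[lane + i] and tmp > vals[lane]:  # skip lanes past row_end
                    vals[lane], args[lane] = tmp, arg_tmp
            i //= 2
        return vals[0], args[0]                           # lane 0 holds max and argmax

    # Lane 2 sits past the segment end; without the `valid` guard its stale 7.0
    # would win and report argmax 2:
    assert warp_max([5., 2., 7., 1.], [0, 1, 2, 3], [True, True, False, False]) == (5., 0)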
@@ -246,7 +250,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   return std::make_tuple(out, arg_out);
 }

-template <typename scalar_t, ReductionType REDUCE>
+template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
 __global__ void
 segment_coo_kernel(const scalar_t *src_data,
                    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
@@ -264,8 +268,12 @@ segment_coo_kernel(const scalar_t *src_data,
                    row_idx, index_info);
   int idx = index_info.data[offset], next_idx;
-  scalar_t val = src_data[row_idx], tmp;
-  int64_t arg = row_idx % index_info.sizes[index_info.dims - 1], arg_tmp;
+  scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
+  int64_t arg, arg_tmp;
+  if (REDUCE == MIN || REDUCE == MAX) {
+    arg = row_idx % index_info.sizes[index_info.dims - 1];
+  }

 #pragma unroll
   for (int i = 1; i < 32; i *= 2) {
@@ -298,7 +306,7 @@ __global__ void segment_coo_broadcast_kernel(
   // read and write is performed in column-major order. The intermediate
   // results are written via atomics.

-  int row_start = blockIdx.x * (blockDim.y + threadIdx.y) * TB;
+  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
   int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

   if (row_start < E && col_idx < K) {
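
The broadcast-kernel change is a one-character precedence fix: the old expression scaled blockDim.y + threadIdx.y by the block index, so threads inside a block computed overlapping, partly out-of-range row offsets instead of consecutive TB-row slabs. With made-up launch values:

    # blockIdx.x = 2, blockDim.y = 4, threadIdx.y = 1, TB = 4 (arbitrary numbers)
    blockIdx_x, blockDim_y, threadIdx_y, TB = 2, 4, 1, 4
    old = blockIdx_x * (blockDim_y + threadIdx_y) * TB  # 40: thread index multiplies the stride
    new = (blockIdx_x * blockDim_y + threadIdx_y) * TB  # 36: each thread owns a distinct slab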
@@ -371,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
     if (K == 1) {
-      segment_coo_kernel<scalar_t, REDUCE>
+      segment_coo_kernel<scalar_t, REDUCE, true>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                  out_data, arg_out_data, E);
     } else if (avg_len <= 8) {
@@ -397,12 +405,19 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   });

   if (reduce == "mean") {
-    auto count = at::empty_like(index, out.options());
+    auto sizes = index.sizes().vec();
+    sizes[reduce_dim] = out.size(reduce_dim);
+    auto count = at::zeros(sizes, out.options());
     AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
       auto count_data = count.DATA_PTR<scalar_t>();
-      AT_ASSERTM(false); // TODO
+      segment_coo_kernel<scalar_t, ADD, false>
+          <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
+                                                 count_data, nullptr, E);
     });
-    out = out / count;
+    count.clamp_(1);
+    out.div_(count);
     arg_out = count;
   }
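
Rather than the old AT_ASSERTM(false) placeholder, the 'mean' path now launches the same COO kernel with HAS_VAL = false: every element contributes the constant 1, so an ADD reduction yields per-segment counts, and clamp_(1) keeps empty segments from dividing by zero. The same idea expressed through the Python API from this diff (toy tensors, not the commit's code):

    import torch
    from torch_scatter import segment_coo

    src = torch.randn(6, device='cuda')
    index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')

    summed = segment_coo(src, index, reduce='add')                  # per-segment sums
    count = segment_coo(torch.ones_like(src), index, reduce='add')  # what HAS_VAL=false computes
    mean = summed / count.clamp(min=1)                              # segment 2 is empty and stays 0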
...
+import platform
 import os.path as osp
 from glob import glob
 from setuptools import setup, find_packages
@@ -10,27 +11,33 @@ USE_GPU = True
 if '--cpu' in argv:
     USE_GPU = False

-extra_compile_args = []
+cxx_extra_compile_args = []
+nvcc_extra_compile_args = ['-arch=sm_35']
+if platform.system() != 'Windows':
+    cxx_extra_compile_args += ['-Wno-unused-variable']

 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
-    extra_compile_args += ['-DVERSION_GE_1_3']
+    cxx_extra_compile_args += ['-DVERSION_GE_1_3']
+    nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 ext_modules = []

 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
 ext_modules += [
     CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=extra_compile_args) for ext in exts
+                 extra_compile_args=cxx_extra_compile_args) for ext in exts
 ]
+# ['-Wno-unused-variable'] if platform.system() != 'Windows' else []
 if CUDA_HOME is not None and USE_GPU:
     exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cuda', '*.cpp'))]
     ext_modules += [
-        CUDAExtension(f'torch_scatter.{ext}_cuda',
-                      [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'],
-                      extra_compile_args=extra_compile_args) for ext in exts
+        CUDAExtension(
+            f'torch_scatter.{ext}_cuda',
+            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
+                'cxx': cxx_extra_compile_args,
+                'nvcc': nvcc_extra_compile_args,
+            }) for ext in exts
     ]

 __version__ = '1.5.0'
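
BuildExtension accepts a dict for extra_compile_args, routing the 'cxx' list to the host compiler and the 'nvcc' list to the device compiler; that is what lets -Wno-unused-variable stay host-only while -arch=sm_35, the compute-capability-3.5 target from the commit title, reaches only nvcc. A minimal sketch for a single extension (module and file names follow the f'cuda/{ext}...' pattern above, not the commit's literal code):

    from torch.utils.cpp_extension import CUDAExtension

    segment_ext = CUDAExtension(
        'torch_scatter.segment_cuda',
        ['cuda/segment.cpp', 'cuda/segment_kernel.cu'],
        extra_compile_args={
            'cxx': ['-Wno-unused-variable'],  # host (C++) flags only
            'nvcc': ['-arch=sm_35'],          # device (CUDA) flags only
        })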
...
@@ -22,19 +22,20 @@ def test_forward(dtype, device):
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

-    out = scatter_min(src, index, dim=0)[0]
-    grad_out = torch.randn_like(out)
-    print(grad_out)
-    out.backward(grad_out)
-    print(src.grad)
+    # out = scatter_min(src, index, dim=0)[0]
+    # grad_out = torch.randn_like(out)
+    # print(grad_out)
+    # out.backward(grad_out)
+    # print(src.grad)
     src.grad = None

-    out = segment_csr(src, indptr, reduce='min')[0]
-    out.backward(grad_out)
-    print(src.grad)
+    out = segment_csr(src, indptr, reduce='mean')
+    print('CSR', out)
+    # out.backward(grad_out)
+    # print(src.grad)

     # out = out[0] if isinstance(out, tuple) else out
     # out.backward(torch.randn_like(out))

-    out = segment_coo(src, index, reduce='any')
+    out = segment_coo(src, index, reduce='mean')
     print('COO', out)
 import torch
-from torch_scatter.utils import min_value, max_value
+from torch_scatter.helpers import min_value, max_value

 if torch.cuda.is_available():
     from torch_scatter import segment_cuda, gather_cuda

@@ -63,12 +63,16 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
         fill_value = min_value(src.dtype)
         out = src.new_full(size, fill_value)

     out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)

     if fill_value != 0:
         out.masked_fill_(out == fill_value, 0)

-    return out if arg_out is None else (out, arg_out)
+    if reduce == 'min' or reduce == 'max':
+        return out, arg_out
+    else:
+        return out


 def segment_csr(src, indptr, out=None, reduce='add'):
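
Because the mean path now returns the count tensor through arg_out, the old `return out if arg_out is None else (out, arg_out)` would have handed (out, count) to callers using reduce='mean'; keying the tuple on the reduction type keeps the contract explicit. Expected usage after this change (toy tensors):

    import torch
    from torch_scatter import segment_coo

    src = torch.randn(6, device='cuda')
    index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')

    out = segment_coo(src, index, reduce='mean')          # a single tensor
    out, arg_out = segment_coo(src, index, reduce='max')  # tensor plus arg indices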
...