Commit 9a91c42d authored by rusty1s

reduce op in segment_csr

parent 611b2994
@@ -7,7 +7,7 @@ from scipy.io import loadmat
import torch
from torch_scatter import scatter_add
from torch_scatter.segment import segment_add_csr, segment_add_coo
from torch_scatter import segment_csr, segment_coo
iters = 20
device = 'cuda'
@@ -51,8 +51,8 @@ def correctness(dataset):
x = x.unsqueeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
out2 = segment_add_coo(x, row, dim_size=dim_size)
out3 = segment_add_csr(x, rowptr)
out2 = segment_coo(x, row, dim_size=dim_size)
out3 = segment_csr(x, rowptr)
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
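For reference, a rough sketch of what this correctness check compares, assuming a 1D or 2D `x` (the helper names below are hypothetical and not part of the benchmark):

def segment_sum_csr_reference(x, rowptr):
    # Sum x over each slice [rowptr[i], rowptr[i + 1]).
    bounds = rowptr.tolist()
    out = x.new_zeros(len(bounds) - 1, *x.shape[1:])
    for i, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        out[i] = x[lo:hi].sum(dim=0)
    return out

def segment_sum_coo_reference(x, row, dim_size):
    # Scatter-sum x into dim_size buckets given by the sorted row index.
    out = x.new_zeros(dim_size, *x.shape[1:])
    out.index_add_(0, row, x)
    return out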
@@ -104,7 +104,7 @@ def timing(dataset):
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = segment_add_coo(x, row, dim_size=dim_size)
out = segment_coo(x, row, dim_size=dim_size)
del out
torch.cuda.synchronize()
t3.append(time.perf_counter() - t)
@@ -116,7 +116,7 @@ def timing(dataset):
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = segment_add_csr(x, rowptr)
out = segment_csr(x, rowptr)
del out
torch.cuda.synchronize()
t4.append(time.perf_counter() - t)
@@ -2,28 +2,33 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt);
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
at::Tensor out);
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce);
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce);
at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) {
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(indptr);
if (out_opt.has_value())
CHECK_CUDA(out_opt.value());
return segment_add_csr_cuda(src, indptr, out_opt);
return segment_csr_cuda(src, indptr, out_opt, reduce);
}
at::Tensor segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
CHECK_CUDA(out);
return segment_add_coo_cuda(src, index, out);
return segment_coo_cuda(src, index, out, reduce);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("segment_add_csr", &segment_add_csr, "Segment Add CSR (CUDA)");
m.def("segment_add_coo", &segment_add_coo, "Segment Add COO (CUDA)");
m.def("segment_csr", &segment_csr, "Segment CSR (CUDA)");
m.def("segment_coo", &segment_coo, "Segment COO (CUDA)");
}
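From Python, the new bindings return a pair `(out, arg_out)` where `arg_out` is `None` unless `reduce` is `'min'` or `'max'`. A minimal call sketch, assuming the extension is built as `torch_scatter.segment_cuda` (the name the Python wrapper further below imports):

import torch
from torch_scatter import segment_cuda

src = torch.randn(6, device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')
out, arg_out = segment_cuda.segment_csr(src, indptr, None, 'max')
# arg_out holds the winning source indices for 'min'/'max'; it is None for 'add'/'mean'.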
@@ -10,6 +10,11 @@
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
#define ADD 0
#define MEAN 1
#define MIN 2
#define MAX 3
// We need our own `IndexToOffset` implementation since we do not want to access
// the last element of the `indptr`.
template <typename T, typename I> struct IndexPtrToOffset {
@@ -26,14 +31,13 @@ template <typename T, typename I> struct IndexPtrToOffset {
}
};
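Since the body of `IndexPtrToOffset` is collapsed above, here is a rough contiguous-only sketch of the mapping it performs (an illustration under that assumption, not the actual strided implementation):

def indptr_offset_contiguous(row_idx, P):
    # For a contiguous `indptr` of shape (..., P), output row `row_idx` reads its
    # segment boundaries from indptr[offset] and indptr[offset + 1], so the last
    # pointer of each length-P slice is only ever reached through the `+ 1`.
    return (row_idx // (P - 1)) * P + row_idx % (P - 1)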
template <typename scalar_t, int TB>
template <typename scalar_t, int REDUCE, int TB>
__global__ void segment_add_csr_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t E) {
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t E) {
// Each warp processes exactly `32/TB` rows. We usually set `TB=32` and only
// make use of it in case the average row length is less than 32.
// Each warp processes exactly `32/TB` rows.
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / TB;
@@ -44,30 +48,90 @@ __global__ void segment_add_csr_kernel(
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
scalar_t val, tmp;
int64_t arg_val, arg_tmp;
if (REDUCE == ADD) {
val = (scalar_t)0;
} else if (REDUCE == MEAN) {
val = (scalar_t)0;
} else if (REDUCE == MIN) {
val = std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
val = std::numeric_limits<scalar_t>::min();
}
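// Note: for floating-point types, std::numeric_limits<scalar_t>::min() is the
// smallest positive value rather than the most negative one;
// std::numeric_limits<scalar_t>::lowest() would be the safer identity for a
// max-reduction over negative inputs.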
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
val += src_data[offset + src_idx]; // "Mostly" coalesced read.
tmp = src_data[offset + src_idx]; // "Mostly" coalesced read.
if (REDUCE == ADD) {
val += tmp;
} else if (REDUCE == MEAN) {
val += tmp;
} else if (REDUCE == MIN && tmp < val) {
val = tmp;
arg_val = src_idx;
} else if (REDUCE == MAX && tmp > val) {
val = tmp;
arg_val = src_idx;
}
}
#pragma unroll
for (int i = TB / 2; i > 0; i /= 2) {
// Parallel reduction inside a single warp.
val += __shfl_down_sync(FULL_MASK, val, i);
tmp = __shfl_down_sync(FULL_MASK, val, i);
if (REDUCE == ADD) {
val += tmp;
} else if (REDUCE == MEAN) {
val += tmp;
} else if (REDUCE == MIN) {
arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
if (tmp < val) {
val = tmp;
arg_val = arg_tmp;
}
} else if (REDUCE == MAX) {
arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
if (tmp > val) {
val = tmp;
arg_val = arg_tmp;
}
}
}
if (lane_idx == 0) {
out_data[row_idx] = val; // "Mostly" coalesced write.
// "Mostly" coalesced write.
if (REDUCE == ADD) {
out_data[row_idx] = val;
} else if (REDUCE == MEAN) {
out_data[row_idx] = val / (scalar_t)max(row_end - row_start, 1);
} else if (REDUCE == MIN) {
if (row_end - row_start > 0) {
out_data[row_idx] = val;
arg_out_data[row_idx] = arg_val;
} else {
out_data[row_idx] = 0;
}
} else if (REDUCE == MAX) {
if (row_end - row_start > 0) {
out_data[row_idx] = val;
arg_out_data[row_idx] = arg_val;
} else {
out_data[row_idx] = 0;
}
}
}
}
}
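To make the kernel's per-row semantics concrete, a minimal host-side sketch for a 1D `src` (the function below is illustrative only and mirrors, rather than replaces, the CUDA code):

def segment_csr_reference(src, indptr, reduce):
    # One output per pointer pair; empty segments produce 0, and for 'min'/'max'
    # their arg entry keeps the fill value src.numel() used by the host code.
    out, arg = [], []
    bounds = indptr.tolist()
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        seg = src[lo:hi]
        if reduce in ('add', 'mean'):
            val = seg.sum()
            out.append(val / max(hi - lo, 1) if reduce == 'mean' else val)
        else:  # 'min' or 'max'
            if hi == lo:
                out.append(src.new_zeros(()))
                arg.append(src.numel())
            else:
                val, idx = seg.min(0) if reduce == 'min' else seg.max(0)
                out.append(val)
                arg.append(lo + int(idx))
    return out, arg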
template <typename scalar_t>
template <typename scalar_t, int REDUCE>
__global__ void segment_add_csr_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
scalar_t *out_data, size_t N, size_t K, size_t E) {
scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {
// Each thread processes exactly one row. This turned out to be more efficient
// than using shared memory, since it avoids synchronization barriers.
@@ -81,19 +145,62 @@ __global__ void segment_add_csr_broadcast_kernel(
int row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = (scalar_t)0;
scalar_t val, tmp;
int64_t arg_val;
if (REDUCE == ADD) {
val = (scalar_t)0;
} else if (REDUCE == MEAN) {
val = (scalar_t)0;
} else if (REDUCE == MIN) {
val = std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
val = std::numeric_limits<scalar_t>::min();
}
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int src_idx = row_start; src_idx < row_end; src_idx++) {
val += src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
tmp = src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
if (REDUCE == ADD) {
val += tmp;
} else if (REDUCE == MEAN) {
val += tmp;
} else if (REDUCE == MIN && tmp < val) {
val = tmp;
arg_val = src_idx;
} else if (REDUCE == MAX && tmp > val) {
val = tmp;
arg_val = src_idx;
}
}
out_data[thread_idx] = val; // Coalesced write.
// Coalesced write.
if (REDUCE == ADD) {
out_data[thread_idx] = val;
} else if (REDUCE == MEAN) {
out_data[thread_idx] = val / (scalar_t)max(row_end - row_start, 1);
} else if (REDUCE == MIN) {
if (row_end - row_start > 0) {
out_data[thread_idx] = val;
arg_out_data[thread_idx] = arg_val;
} else {
out_data[thread_idx] = 0;
}
} else if (REDUCE == MAX) {
if (row_end - row_start > 0) {
out_data[thread_idx] = val;
arg_out_data[thread_idx] = arg_val;
} else {
out_data[thread_idx] = 0;
}
}
}
}
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) {
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce) {
AT_ASSERTM(src.dim() >= indptr.dim());
for (int i = 0; i < indptr.dim() - 1; i++)
@@ -104,7 +211,7 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
at::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value();
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i));
@@ -115,10 +222,15 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
out = at::empty(sizes, src.options());
}
at::optional<at::Tensor> arg_out = at::nullopt;
if (reduce == "min" || reduce == "max") {
arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
}
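// `arg_out` is pre-filled with src.size(reduce_dim), an out-of-range index,
// so rows whose segment is empty can be recognized after the kernel has run.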
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
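// N: total number of output rows across all leading (batch) dimensions,
// K: number of feature columns reduced alongside each row (K > 1 triggers the
// broadcast kernel), E: length of `src` along the reduce dimension.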
auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
// auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
auto stream = at::cuda::getCurrentCUDAStream();
@@ -126,37 +238,56 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
// Select the right kernel based on average row length and whether we need
// Select the right kernel based on the reduce operation and whether we need
// broadcasting capabilities (K > 1):
if (K == 1 && avg_length <= 4) {
segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
} else if (K == 1 && avg_length <= 8) {
segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, N, E);
} else if (K == 1 && avg_length <= 16) {
segment_add_csr_kernel<scalar_t, 16>
<<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E);
} else if (K == 1) {
segment_add_csr_kernel<scalar_t, 32>
if (K == 1 && reduce == "add") {
segment_add_csr_kernel<scalar_t, ADD, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, E);
} else {
segment_add_csr_broadcast_kernel<scalar_t>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, N, K, E);
out_data, nullptr, N, E);
} else if (K == 1 && reduce == "mean") {
segment_add_csr_kernel<scalar_t, MEAN, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, nullptr, N, E);
} else if (K == 1 && reduce == "min") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_kernel<scalar_t, MIN, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, arg_out_data, N, E);
} else if (K == 1 && reduce == "max") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_kernel<scalar_t, MAX, 1>
<<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
out_data, arg_out_data, N, E);
} else if (reduce == "add") {
segment_add_csr_broadcast_kernel<scalar_t, ADD>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, nullptr, N, K, E);
} else if (reduce == "mean") {
segment_add_csr_broadcast_kernel<scalar_t, MEAN>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, nullptr, N, K, E);
} else if (reduce == "min") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_broadcast_kernel<scalar_t, MIN>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
} else if (reduce == "max") {
auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
segment_add_csr_broadcast_kernel<scalar_t, MAX>
<<<BLOCKS(1, N * K), THREADS, 0, stream>>>(
src_data, indptr_info, out_data, arg_out_data, N, K, E);
}
});
return out;
return std::make_tuple(out, arg_out);
}
template <typename scalar_t>
template <typename scalar_t, int REDUCE>
__global__ void segment_add_coo_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E) {
scalar_t *out_data, int64_t *arg_out_data, size_t E) {
// Each thread processes exactly one entry. Within a warp, we perform a
// parallel reduction across equal indices, and write the intermediate
@@ -187,11 +318,11 @@ __global__ void segment_add_coo_kernel(
}
}
template <typename scalar_t, int TB>
template <typename scalar_t, int REDUCE, int TB>
__global__ void segment_add_coo_broadcast_kernel(
const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, size_t E, size_t K) {
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
// Each thread processes a single column and `TB` rows. Coalesced read and
// write is performed in column-major order. The intermediate results are
@@ -228,49 +359,60 @@ __global__ void segment_add_coo_broadcast_kernel(
}
}
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
at::Tensor out) {
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) {
AT_ASSERTM(src.dim() >= index.dim());
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(src.size(i) == index.size(i));
src = src.contiguous();
out = out.contiguous();
auto reduce_dim = index.dim() - 1;
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i));
at::optional<at::Tensor> arg_out = at::nullopt;
if (reduce == "min" || reduce == "max") {
arg_out = at::full_like(out, src.size(reduce_dim), index.options());
}
auto E = index.numel();
auto K = src.numel() / index.numel();
auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
auto src_data = src.DATA_PTR<scalar_t>();
auto out_data = out.DATA_PTR<scalar_t>();
// Select the right kernel based on average row length (purely heuristic)
// and whether we need broadcasting capabilities (K > 1):
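// Note: only the ADD specializations are dispatched below; other `reduce`
// values are accepted by the host function but not yet routed to matching
// kernels here.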
if (K == 1)
segment_add_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
src_data, index_info, out_data, E);
else if (avg_length <= 8)
segment_add_coo_broadcast_kernel<scalar_t, 4>
segment_add_coo_kernel<scalar_t, ADD>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info, out_data,
nullptr, E);
else if (avg_len <= 8)
segment_add_coo_broadcast_kernel<scalar_t, ADD, 4>
<<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, E, K);
else if (avg_length <= 16)
segment_add_coo_broadcast_kernel<scalar_t, 8>
stream>>>(src_data, index_info, out_data, nullptr, E, K);
else if (avg_len <= 16)
segment_add_coo_broadcast_kernel<scalar_t, ADD, 8>
<<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
stream>>>(src_data, index_info, out_data, E, K);
else if (avg_length <= 32)
segment_add_coo_broadcast_kernel<scalar_t, 16>
stream>>>(src_data, index_info, out_data, nullptr, E, K);
else if (avg_len <= 32)
segment_add_coo_broadcast_kernel<scalar_t, ADD, 16>
<<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K);
0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
else
segment_add_coo_broadcast_kernel<scalar_t, 32>
segment_add_coo_broadcast_kernel<scalar_t, ADD, 32>
<<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K);
0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
});
return out;
return std::make_tuple(out, arg_out);
}
import time
from itertools import product
import pytest
import torch
from torch_scatter import segment_add, scatter_add
from torch_scatter.segment import segment_add_csr, segment_add_coo
from torch_scatter import segment_coo, segment_csr
from .utils import tensor
@@ -14,101 +12,109 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
device)
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
# src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
# device)
out = segment_add_csr(src, indptr)
print('CSR', out)
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = segment_add_coo(src, index)
print('COO', out)
# out = segment_coo(src, index)
# print('COO', out)
out = segment_csr(src, indptr, reduce='add')
print('CSR', out)
out = segment_csr(src, indptr, reduce='mean')
print('CSR', out)
out = segment_csr(src, indptr, reduce='min')
print('CSR', out)
out = segment_csr(src, indptr, reduce='max')
print('CSR', out)
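# For the inputs above (src = [1..6], indptr = [0, 2, 5, 5, 6]) and a
# floating-point dtype, the results worked out by hand should be roughly:
# add -> [3, 12, 0, 6], mean -> [1.5, 4, 0, 6], min -> [1, 3, 0, 6],
# max -> [2, 5, 0, 6]; for 'min'/'max' a tuple (out, arg_out) is returned,
# with arg_out indexing into src. These values are not asserted here.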
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_benchmark(dtype, device):
from torch_geometric.datasets import Planetoid, Reddit # noqa
# data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
row, col = data.edge_index
print(data.num_edges)
print(row.size(0) / data.num_nodes)
num_repeats = 1
row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
# Warmup
for _ in range(10):
torch.randn(100, 100, device=device).sum()
x = torch.randn(row.size(0), device=device)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Row', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
scatter_add(x, col, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Col', time.perf_counter() - t)
rowcount = segment_add(torch.ones_like(row), row)
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
torch.cuda.synchronize()
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out3 = segment_add_csr(x, rowptr)
torch.cuda.synchronize()
print('CSR', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('COO', time.perf_counter() - t)
assert torch.allclose(out1, out3, atol=1e-2)
assert torch.allclose(out1, out4, atol=1e-2)
x = torch.randn((row.size(0), 64), device=device)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Row + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
scatter_add(x, col, dim=0, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('Scatter Col + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out6 = segment_add_csr(x, rowptr)
torch.cuda.synchronize()
print('CSR + Dim', time.perf_counter() - t)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
torch.cuda.synchronize()
print('COO + Dim', time.perf_counter() - t)
assert torch.allclose(out5, out6, atol=1e-2)
assert torch.allclose(out5, out7, atol=1e-2)
# @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
# def test_benchmark(dtype, device):
# from torch_geometric.datasets import Planetoid, Reddit # noqa
# # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
# row, col = data.edge_index
# print(data.num_edges)
# print(row.size(0) / data.num_nodes)
# num_repeats = 1
# row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
# col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
# # Warmup
# for _ in range(10):
# torch.randn(100, 100, device=device).sum()
# x = torch.randn(row.size(0), device=device)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('Scatter Row', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# scatter_add(x, col, dim=0, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('Scatter Col', time.perf_counter() - t)
# rowcount = segment_add(torch.ones_like(row), row)
# rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out3 = segment_add_csr(x, rowptr)
# torch.cuda.synchronize()
# print('CSR', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('COO', time.perf_counter() - t)
# assert torch.allclose(out1, out3, atol=1e-2)
# assert torch.allclose(out1, out4, atol=1e-2)
# x = torch.randn((row.size(0), 64), device=device)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('Scatter Row + Dim', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# scatter_add(x, col, dim=0, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('Scatter Col + Dim', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out6 = segment_add_csr(x, rowptr)
# torch.cuda.synchronize()
# print('CSR + Dim', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
# torch.cuda.synchronize()
# print('COO + Dim', time.perf_counter() - t)
# assert torch.allclose(out5, out6, atol=1e-2)
# assert torch.allclose(out5, out7, atol=1e-2)
@@ -8,7 +8,7 @@ from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
from .segment import segment_add
from .segment import segment_coo, segment_csr
import torch_scatter.composite
@@ -24,7 +24,8 @@ __all__ = [
'scatter_max',
'scatter_min',
'scatter_logsumexp',
'segment_add',
'segment_coo',
'segment_csr',
'torch_scatter',
'__version__',
]
import torch
from torch_scatter.add import scatter_add
if torch.cuda.is_available():
import torch_scatter.segment_cuda
def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
return scatter_add(src, index, dim, out, dim_size, fill_value)
from torch_scatter import segment_cuda
def segment_add_csr(src, indptr, out=None):
return torch_scatter.segment_cuda.segment_add_csr(src, indptr, out)
def segment_add_coo(src, index, dim_size=None):
def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
assert reduce in ['add', 'mean', 'min', 'max']
if out is None:
dim_size = index.max().item() + 1 if dim_size is None else dim_size
size = list(src.size())
size[index.dim() - 1] = dim_size
out = src.new_zeros(size)
torch_scatter.segment_cuda.segment_add_coo(src, index, out)
return out
out = src.new_zeros(size) # TODO: DEPENDENT ON REDUCE
assert index.dtype == torch.long and src.dtype == out.dtype
out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
return out if arg_out is None else (out, arg_out)
def segment_csr(src, indptr, out=None, reduce='add'):
assert reduce in ['add', 'mean', 'min', 'max']
assert indptr.dtype == torch.long
out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
return out if arg_out is None else (out, arg_out)
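A brief usage sketch of the wrappers above (requires a CUDA device, since only the CUDA extension is dispatched; shapes and values are illustrative):

import torch
from torch_scatter import segment_csr, segment_coo

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')

out = segment_csr(src, indptr, reduce='add')            # one value per CSR segment
out, arg_out = segment_csr(src, indptr, reduce='max')   # arg_out indexes into src
out = segment_coo(src, index, dim_size=4, reduce='add')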