Commit 6e561c88 authored by rusty1s

atomics

parent 9725b043
@@ -67,10 +67,30 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       }
     }
   }
+
+  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
+                                           int64_t *arg_address, int64_t arg) {
+    if (REDUCE == ADD) {
+      atomAdd(address, val);
+    } else if (REDUCE == MEAN) {
+      atomAdd(address, val);
+    } else if (REDUCE == MIN && val < *address) {
+      atomMin(address, val);
+    } else if (REDUCE == MAX && val > *address) {
+      atomMax(address, val);
+    }
+
+    if (REDUCE == MIN || REDUCE == MAX) {
+      __syncthreads();
+      if (*address == val) {
+        *arg_address = arg;
+      }
+    }
+  }
 };
 
-// We need our own `IndexToOffset` implementation since we do not want to access
-// the last element of the `indexptr`.
+// We need our own `IndexToOffset` implementation since we do not want to
+// access the last element of the `indexptr`.
 template <typename scalar_t> struct IndexPtrToOffset {
   static inline __host__ __device__ int
   get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
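The `atomAdd`/`atomMin`/`atomMax` calls in the new `atom_write` are the repository's own overloaded helpers rather than raw CUDA intrinsics, which matters because CUDA provides no native atomic min/max for floating-point types. A minimal sketch (illustrative names, not part of this commit) of how such a float helper can be emulated with `atomicCAS`:

// Hypothetical float atomic-max built on atomicCAS, mirroring what a
// custom atomMax overload has to do for floating-point data.
__device__ float atom_max_sketch(float *address, float val) {
  int *address_as_int = (int *)address;
  int old = *address_as_int, assumed;
  do {
    assumed = old;
    // Install max(stored, val); retry if another thread raced us.
    old = atomicCAS(address_as_int, assumed,
                    __float_as_int(fmaxf(val, __int_as_float(assumed))));
  } while (assumed != old);
  return __int_as_float(old);
}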
@@ -92,8 +112,8 @@ segment_csr_kernel(const scalar_t *src_data,
                    scalar_t *out_data, int64_t *arg_out_data, size_t N,
                    size_t E) {
 
-  // Each warp processes exactly `32/TB` rows and aggregates all row values via
-  // a parallel reduction.
+  // Each warp processes exactly `32/TB` rows and aggregates all row values
+  // via a parallel reduction.
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
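Assuming `TB = 4` for illustration (the kernel is templated over `TB`), the index arithmetic maps threads to rows as follows, so one 32-thread warp covers `32/TB = 8` rows:

// Illustration only, not part of the commit (TB = 4):
//   thread_idx: 0 1 2 3 | 4 5 6 7 | 8 9 10 11 | ...
//   row_idx:    0 0 0 0 | 1 1 1 1 | 2 2  2  2 | ...
//   lane_idx:   0 1 2 3 | 0 1 2 3 | 0 1  2  3 | ...
// The TB lanes of a row stride through its values with step TB before
// combining their partial results in the shuffle reduction below.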
@@ -106,7 +126,7 @@ segment_csr_kernel(const scalar_t *src_data,
                         indptr_info.strides[indptr_info.dims - 1]);
 
     scalar_t val = Reducer<scalar_t, REDUCE>::init();
-    int64_t arg, tmp;
+    int64_t arg, arg_tmp;
 
     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
     for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
@@ -118,10 +138,10 @@ segment_csr_kernel(const scalar_t *src_data,
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX) {
-        tmp = __shfl_down_sync(FULL_MASK, arg, i);
+        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
      }
       Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, tmp);
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }
 
     if (lane_idx == 0) {
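Stripped of the argmin/argmax bookkeeping, the loop above is the standard shuffle-down warp reduction. A self-contained sketch for a full warp (`TB = 32`; illustrative names, not from this commit):

// Warp-wide sum via __shfl_down_sync; lane 0 ends up holding the
// total over all 32 lanes (0xffffffff is the full-warp mask).
__device__ float warp_reduce_sum(float val) {
  for (int offset = 32 / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}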
@@ -241,20 +261,27 @@ segment_coo_kernel(const scalar_t *src_data,
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
     scalar_t val = src_data[row_idx], tmp;
+    int64_t arg = row_idx % index_info.sizes[index_info.dims - 1], arg_tmp;
 
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
+      if (REDUCE == MIN || REDUCE == MAX) {
+        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
+      }
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       assert(idx >= next_idx);
       if (lane_idx >= i && idx == next_idx)
-        val += tmp;
+        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
     }
 
     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || idx != next_idx) {
-      atomAdd(out_data + idx, val);
+      Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
+                                            arg_out_data + idx, arg);
     }
   }
 }
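The `__shfl_up_sync` loop is a segmented inclusive warp scan: the `idx == next_idx` guard keeps partial sums from crossing row boundaries, and only the last lane of each segment (detected via the final shuffle-down) issues the atomic write. The unsegmented skeleton, as an illustrative sketch:

// Plain inclusive warp prefix-sum via __shfl_up_sync (names are
// illustrative). The kernel above adds the idx == next_idx guard so
// values from different rows never merge.
__device__ float warp_inclusive_scan_sum(float val, int lane_idx) {
  for (int i = 1; i < 32; i *= 2) {
    float tmp = __shfl_up_sync(0xffffffff, val, i);
    if (lane_idx >= i)
      val += tmp;
  }
  return val;
}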
@@ -365,5 +392,9 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
         0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
   });
 
+  if (reduce == "mean") {
+    AT_ASSERTM(false); // TODO: DIVIDE ENTRIES.
+  }
+
   return std::make_tuple(out, arg_out);
 }
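One plausible way to resolve the TODO later, sketched in ATen for the 1-D case (an assumption, not part of this commit): accumulate with 'add', then count the entries landing in each output slot and divide, clamping the count so empty slots are left untouched:

// Hypothetical mean post-processing (1-D case only):
auto count = at::zeros_like(out);
count.scatter_add_(0, index, at::ones_like(src)); // entries per slot
out.div_(count.clamp_min_(1));                    // empty slots divide by 1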
@@ -18,18 +18,12 @@ def test_forward(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
-    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-
-    # out = segment_coo(src, index)
-    # print('COO', out)
 
     out = segment_csr(src, indptr, reduce='add')
     print('CSR', out)
     out = segment_csr(src, indptr, reduce='mean')
     print('CSR', out)
     out = segment_csr(src, indptr, reduce='min')
     print('CSR', out)
     out = segment_csr(src, indptr, reduce='max')
     print('CSR', out)
+
+    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
+    out = segment_coo(src, index, reduce='add')
+    print('COO', out)
 
 
 # @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
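For reference, `indptr = [0, 2, 5, 5, 6]` splits `src = [1, 2, 3, 4, 5, 6]` into the segments `[1, 2]`, `[3, 4, 5]`, `[]`, `[6]`, so `reduce='add'` should print `[3, 12, 0, 6]` (assuming empty segments reduce to zero for 'add'). The prints could eventually become assertions, e.g.:

    # Hand-checked expectation for the 'add' case; assumes empty
    # segments contribute zero.
    expected = tensor([3, 12, 0, 6], dtype, device)
    assert segment_csr(src, indptr, reduce='add').tolist() == expected.tolist()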