"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "9e656dd3b2bef5cec5c05a8cf393a4bff0231f86"
Commit 6e561c88 authored by rusty1s

atomics

parent 9725b043
@@ -67,10 +67,30 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       }
     }
   }
+
+  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
+                                           int64_t *arg_address, int64_t arg) {
+    if (REDUCE == ADD) {
+      atomAdd(address, val);
+    } else if (REDUCE == MEAN) {
+      atomAdd(address, val);
+    } else if (REDUCE == MIN && val < *address) {
+      atomMin(address, val);
+    } else if (REDUCE == MAX && val > *address) {
+      atomMax(address, val);
+    }
+
+    if (REDUCE == MIN || REDUCE == MAX) {
+      __syncthreads();
+      if (*address == val) {
+        *arg_address = arg;
+      }
+    }
+  }
 };

-// We need our own `IndexToOffset` implementation since we do not want to access
-// the last element of the `indexptr`.
+// We need our own `IndexToOffset` implementation since we do not want to
+// access the last element of the `indexptr`.
 template <typename scalar_t> struct IndexPtrToOffset {
   static inline __host__ __device__ int
   get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
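For readers new to the pattern introduced by `atom_write`, here is a minimal, hypothetical standalone sketch (not part of this commit) of the same two-step idea for an argmin: each thread first folds its value into the result with an atomic operation, then the thread whose value ended up stored writes its index. It uses CUDA's built-in `atomicMin` on `int` instead of the repository's `atomMin` wrapper, and it runs in a single block so the barrier makes the final minimum visible before the index is written.

// argmin_sketch.cu -- hypothetical illustration, not the library's code.
#include <climits>
#include <cstdio>

__global__ void argmin_kernel(const int *src, int *out, long long *arg_out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicMin(out, src[i]); // step 1: fold this thread's value in atomically
  }
  __syncthreads();          // wait until all atomics in the block have landed
  if (i < n && *out == src[i]) {
    *arg_out = i;           // step 2: a thread holding the winning value records its index
  }
}

int main() {
  int h_src[8] = {5, 3, 9, 1, 7, 1, 8, 2};
  int *d_src, *d_out;
  long long *d_arg;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_out, sizeof(int));
  cudaMalloc(&d_arg, sizeof(long long));
  int init = INT_MAX; // identity element for MIN
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &init, sizeof(int), cudaMemcpyHostToDevice);
  argmin_kernel<<<1, 8>>>(d_src, d_out, d_arg, 8);
  int out;
  long long arg;
  cudaMemcpy(&out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(&arg, d_arg, sizeof(long long), cudaMemcpyDeviceToHost);
  printf("min=%d argmin=%lld\n", out, arg); // min=1, argmin is 3 or 5
  cudaFree(d_src);
  cudaFree(d_out);
  cudaFree(d_arg);
  return 0;
}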
@@ -92,8 +112,8 @@ segment_csr_kernel(const scalar_t *src_data,
                    scalar_t *out_data, int64_t *arg_out_data, size_t N,
                    size_t E) {
-  // Each warp processes exactly `32/TB` rows and aggregates all row values via
-  // a parallel reduction.
+  // Each warp processes exactly `32/TB` rows and aggregates all row values
+  // via a parallel reduction.
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
@@ -106,7 +126,7 @@ segment_csr_kernel(const scalar_t *src_data,
       indptr_info.strides[indptr_info.dims - 1]);
   scalar_t val = Reducer<scalar_t, REDUCE>::init();
-  int64_t arg, tmp;
+  int64_t arg, arg_tmp;
   offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
   for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
@@ -118,10 +138,10 @@ segment_csr_kernel(const scalar_t *src_data,
   for (int i = TB / 2; i > 0; i /= 2) {
     // Parallel reduction inside a single warp.
     if (REDUCE == MIN || REDUCE == MAX) {
-      tmp = __shfl_down_sync(FULL_MASK, arg, i);
+      arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
     }
     Reducer<scalar_t, REDUCE>::update(
-        &val, __shfl_down_sync(FULL_MASK, val, i), &arg, tmp);
+        &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
   }
   if (lane_idx == 0) {
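The two changed lines above only rename `tmp` to `arg_tmp`; the surrounding loop is the standard warp-level tree reduction. As a refresher, here is a minimal, hypothetical sketch of that idiom in isolation (summing one value per lane with `__shfl_down_sync`, i.e. the TB = 32 case without the arg bookkeeping; not the library's code):

// warp_sum_sketch.cu -- hypothetical illustration of the shuffle reduction.
#include <cstdio>

#define FULL_MASK 0xffffffff

__global__ void warp_sum_kernel(const float *src, float *out) {
  int lane = threadIdx.x & 31;
  float val = src[threadIdx.x];
  // Halve the stride each step; after log2(32) steps lane 0 holds the sum.
  for (int i = 16; i > 0; i /= 2)
    val += __shfl_down_sync(FULL_MASK, val, i);
  if (lane == 0)
    *out = val;
}

int main() {
  float h_src[32], h_out;
  for (int i = 0; i < 32; ++i)
    h_src[i] = 1.0f;
  float *d_src, *d_out;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(d_src, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out); // 32.0
  cudaFree(d_src);
  cudaFree(d_out);
  return 0;
}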
@@ -241,20 +261,27 @@ segment_coo_kernel(const scalar_t *src_data,
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
     scalar_t val = src_data[row_idx], tmp;
+    int64_t arg = row_idx % index_info.sizes[index_info.dims - 1], arg_tmp;
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
+      // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
+      if (REDUCE == MIN || REDUCE == MAX) {
+        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
+      }
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       assert(idx >= next_idx);
       if (lane_idx >= i && idx == next_idx)
-        val += tmp;
+        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
     }
     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || idx != next_idx) {
-      atomAdd(out_data + idx, val);
+      Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
+                                            arg_out_data + idx, arg);
     }
   }
 }
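The kernel above performs a warp-segmented reduction: the indices are sorted, shuffles-up accumulate values within each run of equal indices, and the last lane of a run commits the result with `atom_write`. Here is a minimal, hypothetical sketch of the sum-only version of this idiom (a standalone kernel, not the library's code); the kernel above generalizes it via `Reducer::update` and `atom_write`:

// segmented_sum_sketch.cu -- hypothetical illustration of a warp-segmented sum.
#include <cstdio>

#define FULL_MASK 0xffffffff

__global__ void segmented_sum_kernel(const float *src, const int *idx, float *out) {
  int lane = threadIdx.x & 31;
  float val = src[threadIdx.x];
  int seg = idx[threadIdx.x]; // segment ids are sorted within the warp
  for (int i = 1; i < 32; i *= 2) {
    float tmp = __shfl_up_sync(FULL_MASK, val, i);
    int prev = __shfl_up_sync(FULL_MASK, seg, i);
    if (lane >= i && prev == seg)
      val += tmp; // only accumulate values belonging to the same segment
  }
  int next = __shfl_down_sync(FULL_MASK, seg, 1);
  if (lane == 31 || seg != next)
    atomicAdd(out + seg, val); // last lane of each run commits the segment total
}

int main() {
  float h_src[32], h_out[2] = {0.f, 0.f};
  int h_idx[32];
  for (int i = 0; i < 32; ++i) {
    h_src[i] = 1.0f;
    h_idx[i] = i < 20 ? 0 : 1; // two segments: 20 and 12 entries
  }
  float *d_src, *d_out;
  int *d_idx;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_idx, sizeof(h_idx));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  segmented_sum_kernel<<<1, 32>>>(d_src, d_idx, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("seg0 = %f, seg1 = %f\n", h_out[0], h_out[1]); // 20.0 and 12.0
  cudaFree(d_src);
  cudaFree(d_idx);
  cudaFree(d_out);
  return 0;
}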
@@ -365,5 +392,9 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                  0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
   });
+  if (reduce == "mean") {
+    AT_ASSERTM(false); // TODO: DIVIDE ENTRIES.
+  }
   return std::make_tuple(out, arg_out);
 }
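One possible way to resolve the `TODO` above, sketched as an assumption rather than what this commit implements: reduce with ADD and then divide every output entry by the number of source entries that were scattered into it. For a 1-D output this could look like the following host-side helper (the name `coo_mean_from_sum` is hypothetical):

// Hypothetical host-side sketch (C++/ATen, same translation unit style as the .cu file).
#include <ATen/ATen.h>

at::Tensor coo_mean_from_sum(at::Tensor sum, at::Tensor index, int64_t dim_size) {
  auto ones = at::ones_like(index, sum.options());
  auto count = at::zeros({dim_size}, sum.options());
  count.scatter_add_(0, index, ones); // per-segment entry counts
  count.clamp_min_(1);                // guard empty segments against division by zero
  return sum / count;                 // element-wise mean
}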
@@ -18,18 +18,12 @@ def test_forward(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
-    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    # out = segment_coo(src, index)
-    # print('COO', out)
     out = segment_csr(src, indptr, reduce='add')
     print('CSR', out)
-    out = segment_csr(src, indptr, reduce='mean')
-    print('CSR', out)
-    out = segment_csr(src, indptr, reduce='min')
-    print('CSR', out)
-    out = segment_csr(src, indptr, reduce='max')
-    print('CSR', out)
+    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
+    out = segment_coo(src, index, reduce='add')
+    print('COO', out)
 # @pytest.mark.parametrize('dtype,device', product(dtypes, devices))