"python/sglang/vscode:/vscode.git/clone" did not exist on "81964328b7ed9f12bcc5ce171fe87795fd202de3"
segment_kernel.cu 7 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
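// segment_kernel.cu
//
// CUDA kernels for segment-wise summation: a CSR kernel that assigns a
// thread group per output row, a broadcasting CSR kernel for trailing
// feature dimensions, a COO kernel built on a warp-level segmented scan,
// and a Thrust `reduce_by_key` fallback.
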
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>

#include <thrust/execution_policy.h>

#include "atomics.cuh"
#include "compat.cuh"
#include "index.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS) // Blocks for TB * N threads
#define FULL_MASK 0xffffffff

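// Maps a flat output row index to the memory offset of the matching `indptr`
// entry, honoring arbitrary strides. The last dimension of `indptr` stores
// `sizes[dims - 1]` boundaries, i.e. `sizes[dims - 1] - 1` rows, so the row
// index is decomposed against that reduced extent.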
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
    I offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};

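// CSR segment sum for K == 1: each group of TB consecutive threads reduces
// one output row. The group strides across the row's entries in steps of TB,
// then merges the per-thread partial sums with a warp shuffle reduction.
// TB must be a power of two no larger than the warp size.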
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = (scalar_t)0;

    // Step to the start of this batch in `src` (each batch holds E entries).
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      val += src_data[offset + src_idx];
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2)
      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction

    if (lane_idx == 0) {
      out_data[row_idx] = val;
    }
  }
}

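// Broadcasting CSR variant for K > 1 trailing values per segment entry: one
// thread per output element (row, lane), so reads and writes stay coalesced
// across the K consecutive lanes.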
template <typename scalar_t>
__global__ void segment_add_csr_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
    scalar_t val = (scalar_t)0;

    // Step to the start of this batch in `src` (each batch holds E * K values).
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      // Coalesced read into `src_data`.
      val += src_data[offset + K * src_idx + lane_idx];
    }

    out_data[thread_idx] = val; // Coalesced write into `out_data`
  }
}

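// Host entry point: validates shapes, allocates the output, and selects a
// thread-group size TB from the average row length (K == 1), falling back to
// the broadcasting kernel when K > 1.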
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");

  src = src.contiguous();

  auto reduce_dim = indptr.dim() - 1;
  auto sizes = src.sizes().vec();
  sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
  auto out = at::empty(sizes, src.options());

  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
  auto K = out.numel() / N;
  auto E = src.size(reduce_dim);
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (K == 1 && avg_length <= 4) {
      segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 8) {
      segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
    } else if (K == 1 && avg_length <= 16) {
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else if (K == 1) {
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else {
      segment_add_csr_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
                                                     out_data, N, K, E);
    }
  });

  return out;
}

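// COO segment sum: assumes `index` is sorted. Each warp runs a segmented
// inclusive scan over its values via `__shfl_up_sync`, accumulating only
// while indices match; the last thread of each run of equal indices then
// commits the segment total with a single atomic add.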
template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = thread_idx & (TB - 1);

  if (thread_idx < numel) {
    auto idx = __ldg(index_data + thread_idx);
    scalar_t val = src_data[thread_idx], tmp;

#pragma unroll
    for (int offset = 1; offset < TB; offset *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, offset);
      if (lane_idx >= offset &&
          idx == __ldg(index_data + thread_idx - offset)) {
        val += tmp;
      }
    }

    // Guard the lookahead so the final element does not read past `index_data`.
    if (lane_idx == TB - 1 || thread_idx == numel - 1 ||
        idx != __ldg(index_data + thread_idx + 1)) {
      atomAdd(out_data + idx, val);
    }
  }
}

void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto numel = src.numel();
  // Average segment length (informational; the launch below is fixed at TB = 32).
  auto avg_length = (float)numel / (float)out.numel();

  auto index_data = index.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    segment_add_coo_kernel<scalar_t, 32>
        <<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
                                                   out_data, numel);
  });
}

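// Thrust fallback: reduces consecutive runs of equal indices with
// `thrust::reduce_by_key` on the current CUDA stream, using the THC
// allocator for Thrust's temporary storage.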
void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());

    thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
                          src_data, key_data, out_data);
  });
}