topk.cu
// Portions of this file are adapted from SGLang PR:
// https://github.com/sgl-project/sglang/pull/11194
// and
// https://github.com/sgl-project/sglang/pull/17747

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

constexpr int TopK = 2048;              // DeepSeek V3 sparse attention top-k
constexpr int kThreadsPerBlock = 1024;  // Threads per block

// Shared memory budget
#if defined(USE_ROCM)
constexpr size_t kSmem = 48 * 1024;  // ROCm default: 48KB
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each refinement pass buffers only the candidates that land in the threshold
// bin (typically ~TopK of them), so 4K entries per ping-pong buffer
// (2 buffers = 8K entries = 32KB) suffices; overflow is clamped below.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t);  // 32KB (bytes)
#endif

struct FastTopKParams {
  const float* __restrict__ input;         // [batch, seq_len] Logits
  const int32_t* __restrict__ row_starts;  // [batch] Offset into each row
                                           // (optional)
  int32_t* __restrict__ indices;           // [batch, TopK] Output top-k indices
  int32_t* __restrict__ lengths;           // [batch] Sequence lengths per row
  int64_t input_stride;                    // Stride between rows
};

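// Map a float to a uint32 key whose unsigned order matches float order:
// negatives have all bits flipped, non-negatives get the sign bit set.
// E.g. -1.0f (0xBF800000) -> 0x407FFFFF and 1.0f (0x3F800000) -> 0xBF800000,
// so larger floats always produce larger keys.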
__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

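// Coarse 8-bit key: round to FP16, apply the same order-preserving bit
// transform, and keep the high byte. Used for the cheap first histogram pass.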
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  __half h = __float2half_rn(x);
  uint16_t bits = __half_as_ushort(h);
  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
                                 : static_cast<uint16_t>(bits | 0x8000);
  return static_cast<uint8_t>(key >> 8);
}

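// Trivial case (seq_len <= TopK): every valid position is in the top-k, so
// emit indices 0..seq_len-1 and pad the rest of the output with -1.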
__device__ void naive_topk_cuda(const float* __restrict__ logits,
                                int32_t* __restrict__ output_indices,
                                int32_t seq_len) {
  const int thread_id = threadIdx.x;
  for (int i = thread_id; i < TopK; i += kThreadsPerBlock) {
    output_indices[i] = (i < seq_len) ? i : -1;
  }
}

// Adapted from:
// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87
// by: DarkSharpness
// which is itself an optimized top-k kernel ported from a TileLang kernel.
__device__ void fast_topk_cuda_tl(
    const float* __restrict__ logits,  // Input logits [seq_len]
    int* __restrict__ output_indices,  // Output top-k indices [TopK]
    int logits_offset,                 // Starting offset in logits array
    int seq_len)                       // Number of valid logits to process
{
  constexpr int RADIX = 256;
  constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int));

  alignas(128) __shared__ int shared_histogram[2][RADIX + 128];
  alignas(128) __shared__ int shared_output_count;
  alignas(128) __shared__ int shared_threshold_bin;
  alignas(128) __shared__ int shared_buffered_count[2];

  extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS];

  const int thread_id = threadIdx.x;
  int remaining_k = TopK;

  // Pass 0: Build coarse 8-bit histogram using FP16 high bits
  if (thread_id < RADIX + 1) {
    shared_histogram[0][thread_id] = 0;
  }
  __syncthreads();

  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
    const auto bin = convert_to_uint8(logits[idx + logits_offset]);
    ::atomicAdd(&shared_histogram[0][bin], 1);
  }
  __syncthreads();

  // Helper: Compute a suffix sum over the histogram (hist[i] = number of
  // elements in bins >= i) via 8 doubling steps on ping-pong buffers; the
  // final sums land back in buffer 0.
  auto compute_cumulative_sum = [&]() {
    static_assert(1 << 8 == RADIX,
                  "Radix must be 256 for 8 unrolled iterations");
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      if (C10_LIKELY(thread_id < RADIX)) {
        const int stride = 1 << i;
        const int src_buffer = i & 1;
        const int dst_buffer = src_buffer ^ 1;

        int value = shared_histogram[src_buffer][thread_id];
        if (thread_id < RADIX - stride) {
          value += shared_histogram[src_buffer][thread_id + stride];
        }
        shared_histogram[dst_buffer][thread_id] = value;
      }
      __syncthreads();
    }
  };

  compute_cumulative_sum();

  // Find threshold bin where cumsum crosses remaining_k
  if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
      shared_histogram[0][thread_id + 1] <= remaining_k) {
    shared_threshold_bin = thread_id;
    shared_buffered_count[0] = 0;
    shared_output_count = 0;
  }
  __syncthreads();

  const int threshold_bin = shared_threshold_bin;
  remaining_k -= shared_histogram[0][threshold_bin + 1];

  // Early exit: bins strictly above the threshold hold exactly TopK elements,
  // so emit them directly and skip the refinement passes.
  if (remaining_k == 0) {
    for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
      const int bin = convert_to_uint8(logits[idx + logits_offset]);
      if (bin > threshold_bin) {
        const int output_pos = ::atomicAdd(&shared_output_count, 1);
        output_indices[output_pos] = idx;
      }
    }
    __syncthreads();
    return;
  }

  // Prepare for refinement passes: Process threshold bin
  __syncthreads();
  if (thread_id < RADIX + 1) {
    shared_histogram[0][thread_id] = 0;
  }
  __syncthreads();

  // Scan all elements and:
  // 1. Write indices > threshold_bin to output
  // 2. Buffer indices == threshold_bin for refinement
  // 3. Build histogram for next refinement pass (fused optimization)
  for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) {
    const float logit_value = logits[idx + logits_offset];
    const int bin = convert_to_uint8(logit_value);

    if (bin > threshold_bin) {
      // in top-k, write to output
      const int output_pos = ::atomicAdd(&shared_output_count, 1);
      output_indices[output_pos] = idx;
    } else if (bin == threshold_bin) {
      // Candidate for top-k, needs refinement
      const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1);
      if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
        buffered_indices[0][buffer_pos] = idx;
        // Fused: Build histogram for next pass
        const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
        const int next_bin = (fp32_bits >> 24) & 0xFF;
        ::atomicAdd(&shared_histogram[0][next_bin], 1);
      }
    }
  }
  __syncthreads();

  // ============================================================================
  // Passes 1-4: Refine using 8-bit passes over FP32 bits
  // ============================================================================
  // FP32 bits [31:0] split into 4 bytes processed MSB-first:
  // Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4:
  // bits [7:0]
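  // E.g. the key 0xBF800000 (1.0f after convert_to_uint32_v2) falls in bin
  // 0xBF on pass 1, then bins 0x80, 0x00, 0x00 on passes 2-4.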
#pragma unroll 4
  for (int pass = 0; pass < 4; ++pass) {
    __shared__ int shared_final_k;  // For final pass: remaining slots to fill
    const int src_buffer = pass % 2;
    const int dst_buffer = src_buffer ^ 1;

    // Clamp buffered count to prevent overflow
    const int raw_buffered = shared_buffered_count[src_buffer];
    const int num_buffered =
        (raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS;

    compute_cumulative_sum();

    // Find threshold bin for this pass
    if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k &&
        shared_histogram[0][thread_id + 1] <= remaining_k) {
      shared_threshold_bin = thread_id;
      shared_buffered_count[dst_buffer] = 0;
      shared_final_k = remaining_k - shared_histogram[0][thread_id + 1];
    }
    __syncthreads();

    const int threshold_bin = shared_threshold_bin;
    remaining_k -= shared_histogram[0][threshold_bin + 1];

    // Bit offset for this pass: 24, 16, 8, 0
    const int bit_offset = 24 - pass * 8;

    // Early exit: bins above this pass's threshold exactly exhaust remaining_k
    if (remaining_k == 0) {
      for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
        const int idx = buffered_indices[src_buffer][i];
        const uint32_t fp32_bits =
            convert_to_uint32_v2(logits[idx + logits_offset]);
        const int bin = (fp32_bits >> bit_offset) & 0xFF;
        if (bin > threshold_bin) {
          const int output_pos = ::atomicAdd(&shared_output_count, 1);
          output_indices[output_pos] = idx;
        }
      }
      __syncthreads();
      break;
    }

    // Continue refinement
    __syncthreads();
    if (thread_id < RADIX + 1) {
      shared_histogram[0][thread_id] = 0;
    }
    __syncthreads();

    for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) {
      const int idx = buffered_indices[src_buffer][i];
      const float logit_value = logits[idx + logits_offset];
      const uint32_t fp32_bits = convert_to_uint32_v2(logit_value);
      const int bin = (fp32_bits >> bit_offset) & 0xFF;

      if (bin > threshold_bin) {
        // Definitely in top-k
        const int output_pos = ::atomicAdd(&shared_output_count, 1);
        output_indices[output_pos] = idx;
      } else if (bin == threshold_bin) {
        if (pass == 3) {
          // Final pass (bits [7:0]): no further refinement is possible, so
          // tied keys fill the remaining slots. Filling from the back of the
          // output keeps them clear of entries appended at the front via
          // shared_output_count.
          const int slot = ::atomicAdd(&shared_final_k, -1);
          if (slot > 0) {
            output_indices[TopK - slot] = idx;
          }
        } else {
          // Buffer for next pass and build next histogram
          const int buffer_pos =
              ::atomicAdd(&shared_buffered_count[dst_buffer], 1);
          if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) {
            buffered_indices[dst_buffer][buffer_pos] = idx;
            // Fused: Build histogram for next pass
            const int next_bit_offset = bit_offset - 8;
            const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF;
            ::atomicAdd(&shared_histogram[0][next_bin], 1);
          }
        }
      }
    }
    __syncthreads();
  }
}

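// One block per batch row: take the naive path when the whole row fits in
// top-k, otherwise run the multi-pass radix select.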
__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel(
    const FastTopKParams params) {
  const auto& [input, row_starts, indices, lengths, input_stride] = params;
  const uint64_t batch_idx = blockIdx.x;
  const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx];
  const int seq_len = lengths[batch_idx];
  int* output_indices = indices + batch_idx * TopK;
  const float* logits = input + batch_idx * input_stride;

  if (seq_len <= TopK) {
    // Shortcut: All elements are in top-k
    return naive_topk_cuda(logits, output_indices, seq_len);
  } else {
    return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len);
  }
}

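// Validate tensor shapes/strides and pack raw pointers into FastTopKParams.
// row_starts and indices are optional; absent tensors become null pointers.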
FastTopKParams get_params(
    const at::Tensor& score, const at::Tensor& lengths,
    std::optional<at::Tensor> row_starts_opt = std::nullopt,
    std::optional<at::Tensor> indices_opt = std::nullopt) {
  const int64_t batch_size = score.size(0);

  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1,
              "score must be 2D with contiguous rows");
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() &&
                  lengths.size(0) == batch_size,
              "lengths must be 1D contiguous with size matching batch");

  const int32_t* row_starts_ptr = nullptr;
  if (row_starts_opt.has_value()) {
    const auto& row_starts = *row_starts_opt;
    TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size,
                "row_starts must be 1D with size matching batch");
    row_starts_ptr = row_starts.data_ptr<int32_t>();
  }

  int32_t* indices_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = *indices_opt;
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() &&
                    indices.size(0) == batch_size && indices.size(1) == TopK,
                "indices must be 2D contiguous [batch, TopK]");
    indices_ptr = indices.data_ptr<int32_t>();
  }

  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .row_starts = row_starts_ptr,
      .indices = indices_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
  };
}

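// Raise the kernel's dynamic shared memory limit to `smem_bytes` exactly once
// per process (the lambda runs only on first use thanks to static init).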
template <auto* kernel_func, size_t smem_bytes>
void setup_kernel_smem_once() {
  static const cudaError_t result = []() -> cudaError_t {
#ifdef USE_ROCM
    auto func_ptr = reinterpret_cast<const void*>(kernel_func);
#else
    auto func_ptr = kernel_func;
#endif
    return cudaFuncSetAttribute(
        func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  }();

  TORCH_CHECK(
      result == cudaSuccess,
      "Failed to set kernel shared memory limit: ", cudaGetErrorString(result));
}

}  // namespace vllm

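// Host entry point: for each row of `logits`, write the indices of its
// top-2048 values into `indices`. Rows shorter than TopK are padded with -1.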
void large_context_topk(
    const torch::Tensor& logits, torch::Tensor& indices,
    const torch::Tensor& seq_lens,
    std::optional<torch::Tensor> row_starts = std::nullopt) {
  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
  TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor");
  TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor");
  if (row_starts.has_value()) {
    TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor");
  }

  const auto params = vllm::get_params(logits, seq_lens, row_starts, indices);
  const int64_t batch_size = logits.size(0);

  // Guard onto the device that owns the inputs before querying the stream.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const dim3 grid(static_cast<uint32_t>(batch_size));
  const dim3 block(vllm::kThreadsPerBlock);

  vllm::setup_kernel_smem_once<vllm::topk_kernel, vllm::kSmem>();
  vllm::topk_kernel<<<grid, block, vllm::kSmem, stream>>>(params);

  const cudaError_t result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "large_context_topk kernel failed: ", cudaGetErrorString(result));
}
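
// A minimal host-side usage sketch (hypothetical shapes; assumes the op is
// exposed to callers elsewhere, e.g. via a torch op registration):
//   const int64_t batch = 4, seq_len = 8192;
//   auto opts     = torch::TensorOptions().device(torch::kCUDA);
//   auto logits   = torch::randn({batch, seq_len}, opts.dtype(torch::kFloat32));
//   auto indices  = torch::empty({batch, 2048}, opts.dtype(torch::kInt32));
//   auto seq_lens = torch::full({batch}, seq_len, opts.dtype(torch::kInt32));
//   large_context_topk(logits, indices, seq_lens);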