"tests/models/language/generation/test_granitemoehybrid.py" did not exist on "521b35f799d8d7e22961a79e41256ff770ab2b95"
topk.cu 6.32 KB
Newer Older
1
2
// Persistent TopK kernel for DeepSeek V3 sparse attention indexer.
// See persistent_topk.cuh for kernel implementation.
3

4
5
6
7
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <algorithm>
8
9

#ifndef USE_ROCM
10
  #include "persistent_topk.cuh"
11
12
#endif

13
14
15
16
17
18
19
20
21
22
23
// Host-side launcher for the top-k selection kernels declared in
// persistent_topk.cuh.
//
// Arguments (all validated with TORCH_CHECK, which throws on violation):
//   logits      CUDA float32, 2D [num_rows, stride], contiguous.
//   lengths     CUDA int32, 1D or 2D, contiguous, numel() == num_rows.
//   output      CUDA int32, 2D [num_rows, k], contiguous.
//   workspace   CUDA uint8 scratch buffer; only validated and used on the
//               persistent-kernel path, where it must provide at least
//               num_groups * sizeof(RadixRowState) bytes along dim 0.
//   k           Must equal vllm::persistent::TopK and be <= stride.
//   max_seq_len Forwarded unchanged to the persistent kernel parameters; not
//               used on the FilteredTopK path.
//
// Dispatch: when num_rows > 32 and the device offers at least 128 KiB of
// opt-in shared memory per block, vllm::FilteredTopKRaggedTransform is called.
// Otherwise P::persistent_topk_kernel<VS> is launched, with VS in {4, 2, 1}
// chosen from the divisibility of `stride`.
//
// All work is enqueued on the current CUDA stream of the current device; the
// function does not synchronize. On ROCm builds it always throws.
void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths,
                     torch::Tensor& output, torch::Tensor& workspace, int64_t k,
                     int64_t max_seq_len) {
#ifndef USE_ROCM
  TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor");
  TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor");
  TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor");
  TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported");
  TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32");
  TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32");
  TORCH_CHECK(logits.dim() == 2, "logits must be 2D");
  // The kernels receive raw pointers and use logits.size(1) as the row stride,
  // so a strided view would be silently mis-indexed.
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(lengths.dim() == 1 || lengths.dim() == 2,
              "lengths must be 1D or 2D");
  TORCH_CHECK(lengths.is_contiguous(), "lengths must be contiguous");
  TORCH_CHECK(output.dim() == 2, "output must be 2D");
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  const int64_t num_rows = logits.size(0);
  const int64_t stride = logits.size(1);

  TORCH_CHECK(lengths.numel() == num_rows, "lengths size mismatch");
  TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k,
              "output size mismatch");
  namespace P = vllm::persistent;

  TORCH_CHECK(k == P::TopK, "k must be 2048");
  TORCH_CHECK(k <= stride, "k out of range");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Read the limits of the *current* device on every call. ATen caches one
  // cudaDeviceProp per device, so this is cheap, and unlike a function-local
  // static it stays correct when the process drives several different GPUs
  // and is safe to call from multiple threads.
  const cudaDeviceProp* device_props = at::cuda::getCurrentDeviceProperties();
  const int num_sms = device_props->multiProcessorCount;
  const int max_smem_per_block =
      static_cast<int>(device_props->sharedMemPerBlockOptin);

  if (num_rows > 32 && max_smem_per_block >= 128 * 1024) {
    cudaError_t status = vllm::FilteredTopKRaggedTransform<float, int32_t>(
        logits.data_ptr<float>(), output.data_ptr<int32_t>(),
        lengths.data_ptr<int32_t>(), static_cast<uint32_t>(num_rows),
        static_cast<uint32_t>(k), static_cast<uint32_t>(stride), stream);
    TORCH_CHECK(status == cudaSuccess,
                "FilteredTopK failed: ", cudaGetErrorString(status));
  } else {
    TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor");
    TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8");

    // Smem cap: smaller smem → more CTAs/group → more per-row parallelism for
    // large path. Empirically tuned.
    int effective_max_smem;
    if (num_rows <= 4) {
      effective_max_smem =
          std::min(max_smem_per_block, static_cast<int>(P::kSmemMedium));
    } else if (num_rows <= 8) {
      constexpr int kSmemCapMedium = 48 * 1024;
      effective_max_smem = std::min(max_smem_per_block, kSmemCapMedium);
    } else {
      effective_max_smem = max_smem_per_block;
    }

    // Shared memory left for the per-CTA chunk after the fixed part. Saturate
    // at zero: an unsigned wrap-around here would produce an enormous chunk
    // size; with zero the min_chunk clamp below takes over instead.
    const size_t smem_budget = static_cast<size_t>(effective_max_smem);
    const size_t fixed_smem = static_cast<size_t>(P::kFixedSmemLarge);
    const size_t available_for_ordered =
        smem_budget > fixed_smem ? smem_budget - fixed_smem : 0;
    uint32_t max_chunk_elements =
        static_cast<uint32_t>(available_for_ordered / sizeof(uint32_t));

    // Widest vector width (in elements) that evenly divides the row length.
    uint32_t vec_size = 1;
    if (stride % 4 == 0)
      vec_size = 4;
    else if (stride % 2 == 0)
      vec_size = 2;

    // Round the chunk capacity down to a multiple of vec_size, but never go
    // below one vector per thread of a block.
    max_chunk_elements = (max_chunk_elements / vec_size) * vec_size;
    uint32_t min_chunk = vec_size * P::kThreadsPerBlock;
    if (max_chunk_elements < min_chunk) max_chunk_elements = min_chunk;

    // Number of CTAs needed to cover one row, then the per-CTA chunk obtained
    // by splitting the row evenly across them (rounded up to vec_size and
    // clamped to the capacity computed above).
    uint32_t ctas_per_group =
        (static_cast<uint32_t>(stride) + max_chunk_elements - 1) /
        max_chunk_elements;
    uint32_t chunk_size =
        (static_cast<uint32_t>(stride) + ctas_per_group - 1) / ctas_per_group;
    chunk_size = ((chunk_size + vec_size - 1) / vec_size) * vec_size;
    if (chunk_size > max_chunk_elements) chunk_size = max_chunk_elements;

    size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t);
    if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium;

    // NOTE(review): occupancy is always queried for the <4> instantiation,
    // even when <2> or <1> is launched below — confirm that all
    // instantiations have the same resource usage.
    int occupancy = 1;
    cudaError_t occupancy_status = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock,
        smem_size);
    TORCH_CHECK(occupancy_status == cudaSuccess,
                "Occupancy query failed: ",
                cudaGetErrorString(occupancy_status));
    if (occupancy < 1) occupancy = 1;

    // Size the grid in whole groups of ctas_per_group CTAs, bounded by the
    // number of rows and by the occupancy-derived CTA count, with at least
    // one group.
    uint32_t max_resident_ctas = static_cast<uint32_t>(num_sms) * occupancy;
    uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group,
                                   static_cast<uint32_t>(num_rows));
    if (num_groups == 0) num_groups = 1;
    uint32_t total_ctas = num_groups * ctas_per_group;

    // One RadixRowState per group is carved out of the workspace.
    size_t state_bytes = num_groups * sizeof(P::RadixRowState);
    TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
                "workspace too small, need ", state_bytes, " bytes");

    P::PersistentTopKParams params;
    params.input = logits.data_ptr<float>();
    params.output = output.data_ptr<int32_t>();
    params.lengths = lengths.data_ptr<int32_t>();
    params.num_rows = static_cast<uint32_t>(num_rows);
    params.stride = static_cast<uint32_t>(stride);
    params.chunk_size = chunk_size;
    params.row_states =
        reinterpret_cast<P::RadixRowState*>(workspace.data_ptr<uint8_t>());
    params.ctas_per_group = ctas_per_group;
    params.max_seq_len = static_cast<uint32_t>(max_seq_len);

  // Raises the kernel's dynamic shared-memory limit to smem_size (required
  // for requests above the default limit) and launches it on `stream`.
  #define LAUNCH_PERSISTENT(VS)                                               \
    do {                                                                      \
      auto kernel = &P::persistent_topk_kernel<VS>;                           \
      cudaError_t err = cudaFuncSetAttribute(                                 \
          kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,                \
          static_cast<int>(smem_size));                                       \
      TORCH_CHECK(err == cudaSuccess,                                         \
                  "Failed to set smem: ", cudaGetErrorString(err));           \
      kernel<<<total_ctas, P::kThreadsPerBlock, smem_size, stream>>>(params); \
    } while (0)

    if (vec_size == 4) {
      LAUNCH_PERSISTENT(4);
    } else if (vec_size == 2) {
      LAUNCH_PERSISTENT(2);
    } else {
      LAUNCH_PERSISTENT(1);
    }
  #undef LAUNCH_PERSISTENT
  }

  // Surfaces launch-configuration errors; asynchronous execution faults only
  // appear at a later synchronizing call.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess,
              "persistent_topk failed: ", cudaGetErrorString(err));
#else
  TORCH_CHECK(false, "persistent_topk is not supported on ROCm");
#endif
}