eagle_utils.cu 8.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * Copyright (c) 2025 by SGLang team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Expected tensor shapes for build_tree_efficient:
//   parent_list          [bs, topk * (depth - 1) + 1]
//   selected_index       [bs, draft_token_num - 1]
//   verified_seq_len     [bs]
//   tree_mask            [draft_token_num * (verified_seq_len[0] + draft_token_num) |
//                         draft_token_num * (verified_seq_len[1] + draft_token_num) | ...]
//                        = [sum(verified_seq_len) * draft_token_num + bs * draft_token_num * draft_token_num]
//   positions            [bs * draft_token_num]
//   retrive_index        [bs, draft_token_num]
//   retrive_next_token   [bs, draft_token_num]
//   retrive_next_sibling [bs, draft_token_num]
26
27
28
29
30
31
32
33
34
35
36
37
// Builds the draft-token tree of one request: its attention-mask rows, the
// positions of its draft tokens, and a first-child / next-sibling linked
// representation (retrive_next_token / retrive_next_sibling).
//
// Launch layout: one block per request (blockIdx.x = request index) and one
// thread per draft token (threadIdx.x = token row, row 0 = root). Threads with
// threadIdx.x >= draft_token_num exit immediately.
//
// Work split:
//   - every thread clears mask columns [seq_len + 1, seq_len + draft_token_num)
//     of its own row;
//   - thread 0 writes positions[row 0], all of retrive_index for the request,
//     and links every token into its parent's child list;
//   - thread t > 0 walks token t's ancestor chain, sets the mask column of each
//     token on the chain (itself included) to true, and writes
//     positions[t] = seq_len + chain length.
//
// Preconditions not checked here:
//   - topk > 0 (used as a divisor);
//   - retrive_next_token is pre-filled with -1 (the linking code tests for -1);
//     NOTE(review): retrive_next_sibling is presumably pre-filled with -1 as
//     well, since it is only written for non-first children; confirm with caller;
//   - mask columns [0, seq_len] of every row are never written here and must be
//     initialized by the caller.
//
// A token whose parent is not among the selected tokens is reported with
// printf and skipped (not linked / its ancestor walk stops).
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;

  if (tid >= draft_token_num) {
    return;
  }

  // selected_index holds draft_token_num - 1 entries per request (the root is
  // implicit). Searches must stay inside this row; reading draft_token_num
  // entries would spill into the next request's row (or past the tensor).
  const int num_selected = draft_token_num - 1;
  const int64_t* selected_row = selected_index + static_cast<int64_t>(bid) * num_selected;
  const int64_t* parent_row = parent_list + static_cast<int64_t>(bid) * (topk * (depth - 1) + 1);
  // First slot of this request in positions / retrive_* ([bs, draft_token_num]).
  const int64_t row_base = static_cast<int64_t>(bid) * draft_token_num;

  // Offset of this request's mask block = size of all earlier requests' blocks,
  // each draft_token_num * (verified_seq_len[i] + draft_token_num). Kept in
  // 64 bits: the total can exceed INT_MAX for large batches / long contexts.
  int64_t seq_tree_idx = static_cast<int64_t>(draft_token_num) * draft_token_num * bid;
  for (int i = 0; i < bid; i++) {
    seq_tree_idx += static_cast<int64_t>(verified_seq_len[i]) * draft_token_num;
  }
  const int seq_len = verified_seq_len[bid];
  // Index of mask column seq_len + 1 (selected slot 0) in this thread's row.
  const int64_t token_tree_idx =
      seq_tree_idx + static_cast<int64_t>(seq_len + draft_token_num) * tid + seq_len + 1;
  for (int i = 0; i < num_selected; i++) {
    tree_mask[token_tree_idx + i] = false;
  }

  if (tid == 0) {
    positions[row_base] = seq_len;

    // Walk tokens from last to first. Each token is prepended to its parent's
    // child list, so every child list ends up in ascending token order.
    for (int i = num_selected; i > 0; --i) {
      retrive_index[row_base + i] = row_base + i;
      const int parent_tb_idx = static_cast<int>(selected_row[i - 1] / topk);

      // Row of token i's parent: 0 for the root, otherwise 1 + its slot in
      // selected_row; draft_token_num means "parent was not selected".
      int parent_position = 0;
      if (parent_tb_idx > 0) {
        const int64_t parent_token_idx = parent_row[parent_tb_idx];
        parent_position = draft_token_num;
        for (int j = 0; j < num_selected; ++j) {
          if (selected_row[j] == parent_token_idx) {
            parent_position = j + 1;
            break;
          }
        }
      }
      if (parent_position == draft_token_num) {
        printf(
            "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
            "will be dropped.");
        continue;
      }

      // Prepend token i to the parent's child list; the previous head (if any)
      // becomes token i's next sibling.
      const int64_t previous_head = retrive_next_token[row_base + parent_position];
      retrive_next_token[row_base + parent_position] = i;
      if (previous_head != -1) {
        retrive_next_sibling[row_base + i] = previous_head;
      }
    }
    retrive_index[row_base] = row_base;
  } else {
    // Number of tokens on the chain from this token up to (excluding) the root.
    int position = 0;
    int cur_position = tid - 1;
    while (true) {
      position += 1;
      tree_mask[token_tree_idx + cur_position] = true;
      const int parent_tb_idx = static_cast<int>(selected_row[cur_position] / topk);
      if (parent_tb_idx == 0) {
        break;  // parent is the root
      }

      const int64_t token_idx = parent_row[parent_tb_idx];
      int next_position = -1;
      for (int j = 0; j < num_selected; ++j) {
        if (selected_row[j] == token_idx) {
          next_position = j;
          break;
        }
      }
      // Without this check the walk would continue with an out-of-row slot and
      // write the mask / read selected_index out of bounds.
      if (next_position < 0) {
        printf(
            "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
            "will be dropped.");
        break;
      }
      cur_position = next_position;
    }
    positions[row_base + tid] = position + seq_len;
  }
}

110
111
112
113
114
115
116
117
118
119
120
121
// Host launcher for build_tree_efficient on the current PyTorch CUDA stream.
// Grid: one block per request (bs = parent_list.size(0)).
// Block: draft_token_num threads.
//
// Dtypes are enforced by the typed data_ptr<T>() accessors. These are int64 for
// parent_list, selected_index, positions and retrive_*, and int32 for
// verified_seq_len. tree_mask is reinterpreted byte-wise as bool.
//
// Throws (TORCH_CHECK) on:
//   - topk <= 0;
//   - draft_token_num outside [1, 1024];
//   - a dtype mismatch;
//   - a kernel launch error.
// An empty batch is a no-op.
// TODO (ying) check shape
void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // topk is a divisor inside the kernel.
  TORCH_CHECK(topk > 0, "build_tree_kernel_efficient: topk must be positive, got ", topk);
  // One thread per draft token; 1024 is the CUDA per-block thread limit.
  TORCH_CHECK(
      draft_token_num > 0 && draft_token_num <= 1024,
      "build_tree_kernel_efficient: draft_token_num must be in [1, 1024], got ",
      draft_token_num);

  const int64_t bs = parent_list.size(0);
  if (bs == 0) {
    return;  // nothing to build; a zero-sized grid is an invalid launch
  }
  dim3 grid(static_cast<unsigned int>(bs));
  dim3 block(static_cast<unsigned int>(draft_token_num));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree_efficient<<<grid, block, 0, stream>>>(
      parent_list.data_ptr<int64_t>(),
      selected_index.data_ptr<int64_t>(),
      verified_seq_len.data_ptr<int32_t>(),
      static_cast<bool*>(tree_mask.data_ptr()),
      positions.data_ptr<int64_t>(),
      retrive_index.data_ptr<int64_t>(),
      retrive_next_token.data_ptr<int64_t>(),
      retrive_next_sibling.data_ptr<int64_t>(),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));

  // Launch-configuration errors are only reported through cudaGetLastError().
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(
      launch_err == cudaSuccess, "build_tree_efficient launch failed: ", cudaGetErrorString(launch_err));
}

// Expected tensor shapes for build_tree:
//   parent_list      [bs, topk * (depth - 1) + 1]
//   selected_index   [bs, draft_token_num - 1]
//   verified_seq_len [bs]
//   tree_mask        [draft_token_num * (verified_seq_len[0] + draft_token_num) |
//                     draft_token_num * (verified_seq_len[1] + draft_token_num) | ...]
//                    = [sum(verified_seq_len) * draft_token_num + bs * draft_token_num * draft_token_num]
//   positions        [bs * draft_token_num]
//   retrive_index    [bs, draft_token_num, depth + 2]
149
150
151
152
153
154
155
156
157
158
// Returns the selected_index slot holding the parent of the token stored in
// slot `slot` of selected_row. Returns -1 if that parent is the tree root, or
// num_selected if the parent is not among the selected tokens (malformed tree).
// The search is confined to the num_selected entries of this request's row.
static __device__ __forceinline__ int build_tree_parent_slot(
    const int64_t* selected_row, const int64_t* parent_row, int num_selected, int topk, int slot) {
  const int parent_tb_idx = static_cast<int>(selected_row[slot] / topk);
  if (parent_tb_idx == 0) {
    return -1;
  }
  const int64_t parent_token_idx = parent_row[parent_tb_idx];
  for (int j = 0; j < num_selected; ++j) {
    if (selected_row[j] == parent_token_idx) {
      return j;
    }
  }
  return num_selected;
}

// Builds the draft-token tree of one request: its attention-mask rows, the
// positions of its draft tokens, and, for every leaf, the root-to-leaf path in
// retrive_index.
//
// Launch layout: one block per request (blockIdx.x = request index) and one
// thread per draft token (threadIdx.x = token row, row 0 = root). Every thread
// of the block must reach the __syncthreads() below, so there is no early return
// before it.
//
// Phase 1 (per thread, own row only):
//   - clear mask columns [seq_len + 1, seq_len + draft_token_num);
//   - for t > 0, walk the ancestor chain, setting each chain token's mask column
//     to true, and write positions[t] = seq_len + chain length;
//   - thread 0 writes positions[row 0] and retrive_index[row 0][0].
// Phase 2 (after the barrier, reads other threads' rows):
//   - token t is a leaf iff no other row has its mask column set;
//   - for each leaf, write retrive_index[t][0] = request base and
//     retrive_index[t][1..chain length] = the path from the root's child down to t.
//
// Preconditions not checked here:
//   - topk > 0;
//   - mask columns [0, seq_len] of every row are never written here and must be
//     initialized by the caller.
// A token whose parent is not selected is reported with printf and its walk
// stops there.
__global__ void build_tree(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int topk,
    int depth,
    int draft_token_num) {
  const int bid = blockIdx.x;
  const int tid = threadIdx.x;
  const bool active = tid < draft_token_num;

  const int num_selected = draft_token_num - 1;
  const int64_t* selected_row = selected_index + static_cast<int64_t>(bid) * num_selected;
  const int64_t* parent_row = parent_list + static_cast<int64_t>(bid) * (topk * (depth - 1) + 1);
  // First token slot of this request in positions / retrive_index.
  const int64_t row_base = static_cast<int64_t>(bid) * draft_token_num;
  // Each retrive_index row holds depth + 2 entries.
  const int64_t retrive_stride = static_cast<int64_t>(depth) + 2;

  // 64-bit offsets: sum(seq_len) * draft_token_num can exceed INT_MAX.
  int64_t seq_tree_idx = static_cast<int64_t>(draft_token_num) * draft_token_num * bid;
  int seq_len = 0;
  int position = 0;  // chain length from this token up to (excluding) the root

  // ---- Phase 1: each thread writes only its own mask row. ----
  if (active) {
    for (int i = 0; i < bid; i++) {
      seq_tree_idx += static_cast<int64_t>(verified_seq_len[i]) * draft_token_num;
    }
    seq_len = verified_seq_len[bid];
    // Index of mask column seq_len + 1 (selected slot 0) in this thread's row.
    const int64_t token_tree_idx =
        seq_tree_idx + static_cast<int64_t>(seq_len + draft_token_num) * tid + seq_len + 1;
    for (int i = 0; i < num_selected; i++) {
      tree_mask[token_tree_idx + i] = false;
    }

    if (tid == 0) {
      positions[row_base] = seq_len;
      retrive_index[row_base * retrive_stride] = row_base;
    } else {
      int slot = tid - 1;
      while (true) {
        position += 1;
        tree_mask[token_tree_idx + slot] = true;
        const int parent_slot = build_tree_parent_slot(selected_row, parent_row, num_selected, topk, slot);
        if (parent_slot < 0) {
          break;  // parent is the root
        }
        if (parent_slot == num_selected) {
          printf(
              "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The "
              "token "
              "will be dropped.");
          break;
        }
        slot = parent_slot;
      }
      positions[row_base + tid] = position + seq_len;
    }
  }

  // Phase 2 reads mask rows written by other threads in phase 1.
  __syncthreads();

  if (!active || tid == 0) {
    return;
  }

  // ---- Phase 2: leaf detection and path extraction. ----
  // Count the rows (ancestor chains) that contain token tid. Its own row always
  // does, so a count of 1 means no other token descends from it.
  int rows_containing = 0;
  for (int i = 1; i < draft_token_num; i++) {
    if (tree_mask[seq_tree_idx + static_cast<int64_t>(i) * (draft_token_num + seq_len) + seq_len + tid]) {
      rows_containing++;
    }
  }
  if (rows_containing == 1) {
    int64_t* path = retrive_index + (row_base + tid) * retrive_stride;
    // Re-walk the chain (deterministic, no fixed-size buffer): the leaf goes to
    // path[position] and the root's child to path[1].
    int slot = tid - 1;
    for (int k = position; k >= 1; --k) {
      path[k] = row_base + slot + 1;
      const int parent_slot = build_tree_parent_slot(selected_row, parent_row, num_selected, topk, slot);
      if (parent_slot < 0 || parent_slot == num_selected) {
        break;
      }
      slot = parent_slot;
    }
    path[0] = row_base;
  }
}

224
225
226
227
228
229
230
231
232
233
// Host launcher for build_tree on the current PyTorch CUDA stream.
// Grid: one block per request (bs = parent_list.size(0)).
// Block: draft_token_num threads.
//
// Dtypes are enforced by the typed data_ptr<T>() accessors. These are int64 for
// parent_list, selected_index, positions and retrive_index, and int32 for
// verified_seq_len. tree_mask is reinterpreted byte-wise as bool.
//
// Throws (TORCH_CHECK) on:
//   - topk <= 0;
//   - draft_token_num outside [1, 1024];
//   - a dtype mismatch;
//   - a kernel launch error.
// An empty batch is a no-op.
// TODO (ying) check shape
void build_tree_kernel(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // topk is a divisor inside the kernel.
  TORCH_CHECK(topk > 0, "build_tree_kernel: topk must be positive, got ", topk);
  // One thread per draft token; 1024 is the CUDA per-block thread limit.
  TORCH_CHECK(
      draft_token_num > 0 && draft_token_num <= 1024,
      "build_tree_kernel: draft_token_num must be in [1, 1024], got ",
      draft_token_num);

  const int64_t bs = parent_list.size(0);
  if (bs == 0) {
    return;  // nothing to build; a zero-sized grid is an invalid launch
  }
  dim3 grid(static_cast<unsigned int>(bs));
  dim3 block(static_cast<unsigned int>(draft_token_num));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree<<<grid, block, 0, stream>>>(
      parent_list.data_ptr<int64_t>(),
      selected_index.data_ptr<int64_t>(),
      verified_seq_len.data_ptr<int32_t>(),
      static_cast<bool*>(tree_mask.data_ptr()),
      positions.data_ptr<int64_t>(),
      retrive_index.data_ptr<int64_t>(),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));

  // Launch-configuration errors are only reported through cudaGetLastError().
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess, "build_tree launch failed: ", cudaGetErrorString(launch_err));
}