import logging
import time
from typing import List, Optional, Union

import torch

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    assign_draft_cache_locs,
    fast_topk,
    select_top_k_tokens,
)

logger = logging.getLogger(__name__)


class EAGLEWorker(TpModelWorker):
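    """A draft worker for EAGLE speculative decoding.

    The worker runs a small draft model that shares the embedding and lm_head
    of the target model. In each decode round it drafts a top-k token tree for
    `speculative_num_steps` steps, the target model verifies the whole tree in
    a single forward pass, and the accepted tokens are committed.

    Rough usage sketch (the construction details below are illustrative
    assumptions; in SGLang the scheduler wires the workers together):

        target_worker = TpModelWorker(server_args=args, gpu_id=0, tp_rank=0,
                                      dp_rank=None, nccl_port=port)
        draft_worker = EAGLEWorker(args, gpu_id=0, tp_rank=0, dp_rank=None,
                                   nccl_port=port, target_worker=target_worker)
        out = draft_worker.forward_batch_speculative_generation(batch)
    """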

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Do not capture cuda graph in `super().__init__()`
        # We will capture it later
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        super().__init__(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            server_args=server_args,
            nccl_port=nccl_port,
            dp_rank=dp_rank,
            is_draft_worker=True,
        )
        self.target_worker = target_worker
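        # Per-request extend lengths from the last verification round.
        # Replaced by the mapping returned from verify() and consumed in
        # finish_request() when freeing the KV cache.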
        self.finish_extend_len = []

        # Parse arguments
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.server_args = server_args

        # Share the embedding and lm_head
        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
        self.model_runner.model.set_embed_and_head(embed, head)
        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

        # Create multi-step attn backends and cuda graph runners
        if server_args.attention_backend == "flashinfer":
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                self.model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        elif server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {server_args.attention_backend}"
            )

        self.model_runner.draft_attn_backend = self.draft_attn_backend
        self.init_cuda_graphs()

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None

        if self.server_args.disable_cuda_graph:
            return

        tic = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
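        """Run one step of speculative generation.

        In decode mode: draft a token tree, verify it with the target model,
        and, if any request is still running, extend the draft model's KV
        cache with the newly accepted tokens.

        In extend (prefill) mode: run the target model to fill its KV cache
        and capture hidden states, then prefill the draft model's KV cache.

        Returns (logits_output, next_token_ids, model_worker_batch,
        num_accepted_tokens).
        """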
        if batch.forward_mode.is_decode():
            # Draft
            spec_info: EagleVerifyInput = self.draft(batch)

            # Verify
            (
                next_draft_input,
                logits_output,
                verified_id,
                self.finish_extend_len,
                accept_length_cpu,
                model_worker_batch,
            ) = self.verify(batch, spec_info)
            batch.spec_info = next_draft_input
            # If verified_id is None, all requests have finished
            if batch.spec_info.verified_id is not None:
                self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verified_id,
                model_worker_batch,
                sum(accept_length_cpu),
            )

        else:
            # Forward with the target model and get hidden states.
            # We need the full hidden states to prefill the KV cache of the draft model.
            model_worker_batch = batch.get_model_worker_batch()
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                model_worker_batch
            )

            # Forward with the draft model.
            batch.spec_info = EagleDraftInput(
                hidden_states=logits_output.hidden_states,
                verified_id=next_token_ids,
            )
            self.forward_draft_extend(batch)
            return logits_output, next_token_ids, model_worker_batch, 0

    def draft(self, batch: ScheduleBatch):
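        """Run the draft model and build the token tree for verification.

        Temporary KV-cache slots for all draft branches are allocated up
        front and freed again before returning; the tokens that survive
        verification are re-extended by `forward_draft_extend_after_decode`.
        """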
        self._set_mem_pool(batch, self.model_runner)

        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Allocate cache locations
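        # Reserve topk * speculative_num_steps slots per request;
        # draft_forward() consumes them one step (batch_size * topk slots)
        # at a time.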
        out_cache_loc = batch.alloc_token_slots(
            num_seqs * self.topk * self.speculative_num_steps
        )
        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
        )

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)

        # Get forward batch
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )

        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            # Initialize attention backend
            self.draft_attn_backend.init_forward_metadata(forward_batch)

            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        ret = EagleVerifyInput.create(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.server_args.speculative_num_draft_tokens,
            batch.sampling_info.is_all_greedy,
        )

        # Free cache locations
        batch.token_to_kv_pool.free(out_cache_loc)
        self._set_mem_pool(batch, self.target_worker.model_runner)
        return ret

    def draft_forward(self, forward_batch: ForwardBatch):
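        """Run the draft model for `speculative_num_steps` autoregressive steps.

        Each step expands the top-k candidates of the previous step and
        records the tree scores, tokens, and parent indices that the
        verification step consumes.
        """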
        # Parse args
        spec_info = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # Set inputs
            forward_batch.input_ids = input_ids
            forward_batch.out_cache_loc = out_cache_loc[
                forward_batch.batch_size
                * self.topk
                * i : forward_batch.batch_size
                * self.topk
                * (i + 1)
            ]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
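        """Verify the drafted token tree with a single target-model forward
        pass in TARGET_VERIFY mode, then restore the batch to DECODE mode."""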
        spec_info.prepare_for_verify(batch)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _ = self.target_worker.forward_batch_generation(
            model_worker_batch, skip_sample=True
        )
        spec_info.hidden_states = logits_output.hidden_states
        res = spec_info.verify(batch, logits_output)
        batch.forward_mode = ForwardMode.DECODE
        return res + (model_worker_batch,)

    def forward_draft_extend(self, batch: ScheduleBatch):
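        """Prefill the draft model's KV cache with the target model's hidden
        states and capture the top-k candidates for the first draft step."""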
        self._set_mem_pool(batch, self.model_runner)
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        self.capture_for_decode(logits_output, forward_batch)
        self._set_mem_pool(batch, self.target_worker.model_runner)

    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
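        """Point the batch at `runner`'s memory pools so that the draft and
        target models each read and write their own KV cache."""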
        batch.token_to_kv_pool = runner.token_to_kv_pool
        batch.req_to_token_pool = runner.req_to_token_pool

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
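        """Extend the draft model's KV cache with the tokens accepted during
        verification, then capture the top-k candidates for the next round."""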
        seq_lens_backup = batch.seq_lens
        req_pool_indices_backup = batch.req_pool_indices

        self._set_mem_pool(batch, self.model_runner)
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        self.capture_for_decode(logits_output, forward_batch)
        self._set_mem_pool(batch, self.target_worker.model_runner)

        # Restore the backed-up values, because `seq_lens` and
        # `req_pool_indices` can be modified in `prepare_extend_after_decode`.
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
    ):
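        """Store the draft model's top-k candidates and hidden states in
        `spec_info` so the next decode round can start drafting from them."""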
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        spec_info = forward_batch.spec_info
        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
        spec_info.hidden_states = logits_output.hidden_states

    # Prefix sharing is not supported yet.
    def finish_request(self, reqs: Union[Req, List[Req]]):
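        """Free the draft model's KV cache held by finished requests."""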
        if not isinstance(reqs, list):
            reqs = [reqs]
        for req in reqs:
            if req.rid not in self.finish_extend_len:
                continue
            req_len = (
                len(req.origin_input_ids)
                + len(req.output_ids)
                - self.finish_extend_len[req.rid]
                - 1
            )
            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
                req.req_pool_idx
            ][:req_len]
            self.model_runner.token_to_kv_pool.free(kv_indices)
            self.model_runner.req_to_token_pool.free(req.req_pool_idx)