from typing import List, Optional

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.utils import GenerationBatchResult
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj


class SchedulerPPMixin:
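    """Mixin providing pipeline-parallel (PP) event loops for the scheduler.

    Each outer loop iteration cycles through ``pp_size`` microbatch slots so the
    pipeline stays filled: while one microbatch runs its forward pass, the
    outputs of a previously scheduled microbatch are received and post-processed.
    """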

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism."""
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
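        # One running batch per microbatch slot; the loop below rotates through the
        # slots so that each pipeline stage always has a microbatch to work on.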
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids = result.next_token_ids
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
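                # If the next slot already holds a microbatch from a previous round,
                # its outputs are now available from the upstream stage.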
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None

                    output_result = GenerationBatchResult.from_pp_proxy(
                        logits_output=logits_output,
                        next_pp_outputs=next_pp_outputs,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    # carry the outputs from the last round to the next stage so its worker can run post-processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
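                        # Only one rank per attn-TP group forwards the python objects
                        # from this stage's ranks to the next pipeline stage's ranks.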
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        # FIXME(lsyin): remove this assert
                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, do a self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()

    @DynamicGradMode()
    def event_loop_pp_disagg_prefill(self):
        """
        An event loop for the prefill server in pipeline parallelism.

        Rules:
        1. Each stage runs in the same order and is notified by the previous stage.
        2. Each send/recv operation is blocking and matched by the neighboring stage.

        Regular Schedule:
        ====================================================================
        Stage i                   | Stage i+1
        send ith req              | recv ith req
        send ith proxy            | recv ith proxy
        send prev (i+1)th carry   | recv prev (i+1)th carry
        ====================================================================

        Prefill Server Schedule:
        ====================================================================
        Stage i                        | Stage i+1
        send ith req                   | recv ith req
        send ith bootstrap req         | recv ith bootstrap req
        send ith transferred req       | recv ith transferred req
        send ith proxy                 | recv ith proxy
        send prev (i+1)th carry        | recv prev (i+1)th carry
        send prev (i+1)th release req  | recv prev (i+1)th release req
        ====================================================================

        There are two additional elements compared to the regular schedule:

        1. Bootstrap Requests:
            a. Instead of polling the status on the current workers, we wait for the previous stage's notification to avoid desynchronization.
            b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
            c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

        2. Transferred Requests + Release Requests:
            a. The first stage polls the requests that have finished the KV transfer; each following stage intersects that set with its own finished requests and propagates the consensus down to the last stage.
            b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
            c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        pp_outputs: Optional[PPProxyTensors] = None

        # Bootstrapped request ids (both successful and failed bootstraps)
        bootstrapped_rids: List[str] = []
        transferred_rids: List[str] = []
        release_rids: Optional[List[str]] = None

        # tmbs: transferred request ids, tracked per microbatch slot
        tmbs = [None] * self.pp_size

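        # When release is disabled, transferred requests are drained directly from the
        # inflight queue at the end of each outer iteration instead of through the
        # release handshake described in the docstring.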
        ENABLE_RELEASE = True  # For debug

        while True:
            server_is_idle = True

            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()

                self.process_input_requests(recv_reqs)

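                # Bootstrap consensus: decide which requests finished the bootstrap
                # handshake before admitting them to the waiting queue.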
                if self.pp_group.is_first_rank:
                    # First rank, pop the bootstrap reqs from the bootstrap queue
                    bootstrapped_reqs, failed_reqs = (
                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                            return_failed_reqs=True
                        )
                    )
                    bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                        req.rid for req in failed_reqs
                    ]
                    self.waiting_queue.extend(bootstrapped_reqs)
                else:
                    # Other ranks, receive the bootstrap reqs info from the previous rank and ensure consensus
                    bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                    bootstrapped_reqs = (
                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                            rids_to_check=bootstrapped_rids
                        )
                    )
                    self.waiting_queue.extend(bootstrapped_reqs)

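                # Transfer consensus: intersect the KV-transfer-finished requests across
                # the stages seen so far and propagate the result downstream.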
                if self.pp_group.is_first_rank:
                    transferred_rids = self.get_transferred_rids()
                else:
                    # Other ranks:
                    # 1. recv previous stage's transferred reqs info
                    prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                    # 2. get the current stage's transferred reqs info
                    curr_transferred_rids = self.get_transferred_rids()
                    # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                    transferred_rids = list(
                        set(prev_transferred_rids) & set(curr_transferred_rids)
                    )

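                # Record the consensus rids for this microbatch slot; the release
                # handshake below uses it to know when to expect release messages.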
                tmbs[mb_id] = transferred_rids

                self.process_prefill_chunk()
                mbs[mb_id] = self.get_new_batch_prefill()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids = result.next_token_ids
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                if ENABLE_RELEASE:
                    if self.pp_group.is_last_rank:
                        # At the last stage, all stages have reached consensus to release memory for transferred_rids
                        release_rids = transferred_rids
                        # send to the first rank
                        self.send_pyobj_to_next_stage(release_rids)

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                next_release_rids = None

                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result_disagg_prefill(
                        mbs[next_mb_id], output_result
                    )

                    last_mbs[next_mb_id] = mbs[next_mb_id]

                if ENABLE_RELEASE:
                    if tmbs[next_mb_id] is not None:
                        # recv consensus rids from the previous rank
                        next_release_rids = self.recv_pyobj_from_prev_stage()
                        self.process_disagg_prefill_inflight_queue(next_release_rids)

                # carry the outputs to the next stage
                if not self.pp_group.is_last_rank:
                    if pp_outputs:
                        # send the outputs from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )
                    if ENABLE_RELEASE:
                        if release_rids is not None:
                            self.send_pyobj_to_next_stage(release_rids)

                if not self.pp_group.is_last_rank:
                    # send out reqs to the next stage
                    self.send_pyobj_to_next_stage(recv_reqs)
                    self.send_pyobj_to_next_stage(bootstrapped_rids)
                    self.send_pyobj_to_next_stage(transferred_rids)

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        # FIXME(lsyin): remove this assert
                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs
                release_rids = next_release_rids

                self.running_batch.batch_is_full = False

            if not ENABLE_RELEASE:
                if len(self.disagg_prefill_inflight_queue) > 0:
                    self.process_disagg_prefill_inflight_queue()

            # When the server is idle, self-check and re-init some states
            if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
                self.check_memory()
                self.check_tree_cache()
                self.new_token_ratio = self.init_new_token_ratio