# Two-batch-overlap (TBO): utilities for splitting a batch into two
# micro-batches and executing them in an overlapped fashion.
from __future__ import annotations

3
import copy
4
import dataclasses
5
import logging
6
from dataclasses import replace
7
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
8
9
10
11

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
12
13
14
15
16
from sglang.srt.layers.communicator import (
    CommunicateContext,
    CommunicateSummableTensorPairFn,
    ScatterMode,
)
17
18
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
19
from sglang.srt.layers.quantization import deep_gemm_wrapper
20
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
21
22
23
24
25
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    compute_position,
)
26
27
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
28
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
29
from sglang.srt.utils import BumpAllocator, get_bool_env_var
30

31
if TYPE_CHECKING:
32
    from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
33

34
35
36
37
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

logger = logging.getLogger(__name__)

38
39
40
41

# -------------------------------- Compute Basic Info ---------------------------------------


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def get_token_num_per_seq(
    forward_mode: ForwardMode,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
):
    """Return the fixed number of tokens each sequence contributes in this mode.

    Target-verify processes ``spec_info.draft_token_num`` tokens per sequence,
    decode processes exactly one, and idle processes none. Extend batches have
    variable per-sequence lengths, so ``None`` is returned and callers must not
    rely on a per-seq token count there.
    """
    if forward_mode.is_target_verify():
        return spec_info.draft_token_num
    if forward_mode.is_decode():
        return 1
    if forward_mode.is_idle():
        return 0
    # For extend, we should not use `token_num_per_seq`.
    return None


57
58
59
60
61
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> Optional[int]:
    """Choose the sequence index at which to cut the batch into two halves.

    For extend, the cut balances token counts across sequences; for decode and
    target-verify (fixed tokens per sequence) it is simply half the sequences.
    Idle batches are empty and split at 0.
    """
    if forward_mode == ForwardMode.EXTEND:
        assert extend_lens is not None
        return _split_extend_seqs(extend_lens)
    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        num_seqs = num_tokens // token_num_per_seq
        return num_seqs // 2
    if forward_mode.is_idle():
        assert num_tokens == 0
        return 0
    raise NotImplementedError()
75
76


77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def _is_two_chunk_split_enabled(extend_lens: Optional[Sequence[int]]) -> bool:
    """Decide whether the extend batch is too imbalanced for a per-sequence split.

    If the best balanced per-sequence cut still leaves one side with less than
    `tbo_token_distribution_threshold` (or more than 1 - threshold) of all
    tokens, fall back to "two chunk" mode, which cuts a sequence in the middle.
    Returns False when `extend_lens` is unavailable.
    """
    if extend_lens is None:
        return False

    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
    left_sum = sum(extend_lens[:vanilla_split_seq_index])
    overall_sum = sum(extend_lens)
    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
    # A threshold above 0.5 would make the two bounds below cross over.
    assert threshold <= 0.5, f"{threshold=}"
    return left_sum < overall_sum * threshold or left_sum > overall_sum * (
        1 - threshold
    )


def _split_extend_seqs(arr: Sequence[int]) -> int:
    """Pick the split index for an extend batch.

    Uses the half-token cut when the balanced per-sequence cut would be too
    lopsided (two-chunk mode), otherwise the balanced cut.
    """
    splitter = (
        _split_array_by_cum_less_than_half
        if _is_two_chunk_split_enabled(arr)
        else _split_array_by_balanced_sum
    )
    return splitter(arr)


def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
    left_sum = 0
    overall_sum = sum(arr)
    half_sum = overall_sum // 2
    chosen_index = 0

    for i in range(len(arr)):
        left_sum += arr[i]
        if left_sum > half_sum:
            chosen_index = i
            break

    return chosen_index


def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
114
    overall_sum = sum(arr)
115
116
117
118
119
120
121
122
123
124
125
126
    left_sum = 0
    min_diff = float("inf")
    best_index = 0

    for i in range(1, len(arr)):
        left_sum += arr[i - 1]
        right_sum = overall_sum - left_sum
        diff = abs(left_sum - right_sum)
        if diff <= min_diff:
            min_diff = diff
            best_index = i
        else:
127
            break
128
129

    return best_index
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def _update_device_and_sum_field_from_cpu_field(
    batch: ForwardBatch, cpu_field: str, device_field: str, sum_field: Optional[str] = None
):
    """Re-derive a device tensor (and, optionally, a scalar sum) from its CPU field.

    After `cpu_field` on `batch` has been edited, copies it (async) to the
    configured device into `device_field`, and if `sum_field` is given, stores
    `sum(cpu_value)` there. Silently does nothing when either field is missing
    or the CPU value is not a tensor/list.
    """
    cpu_value = getattr(batch, cpu_field, None)
    old_device_value = getattr(batch, device_field, None)
    if (
        cpu_value is None
        or old_device_value is None
        or not (isinstance(cpu_value, torch.Tensor) or isinstance(cpu_value, list))
    ):
        return

    # Keep the dtype of the previous device tensor when converting from a list.
    new_device_value = (
        cpu_value
        if isinstance(cpu_value, torch.Tensor)
        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
    ).to(device=global_server_args_dict["device"], non_blocking=True)
    setattr(batch, device_field, new_device_value)

    if sum_field is not None:
        sum_value = (
            cpu_value.sum().item()
            if isinstance(cpu_value, torch.Tensor)
            else sum(cpu_value)
        )
        setattr(batch, sum_field, sum_value)


160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
    if seq_index == 0:
        return 0

    offset = 0
    max_seq_len = min(seq_index, spec_info.seq_lens_cpu.shape[0])
    for i in range(max_seq_len):
        offset += (
            spec_info.seq_lens_cpu[i] + spec_info.draft_token_num
        ) * spec_info.draft_token_num
    return offset


def split_spec_info(
    spec_info: Optional[EagleVerifyInput],
    start_seq_index: int,
    end_seq_index: int,
    start_token_index: int,
    end_token_index: int,
):
    """Slice a speculative-verify input down to one TBO child batch.

    Token-indexed fields are cut by [start_token_index, end_token_index) and
    sequence-indexed fields by [start_seq_index, end_seq_index). Returns a new
    object via dataclasses.replace, or None when there is no spec info.
    """
    if spec_info is None:
        return None

    def _token_slice(value):
        return None if value is None else value[start_token_index:end_token_index]

    def _seq_slice(value):
        return None if value is None else value[start_seq_index:end_seq_index]

    draft_token = _token_slice(spec_info.draft_token)
    positions = _token_slice(spec_info.positions)

    custom_mask = spec_info.custom_mask
    if custom_mask is not None and spec_info.draft_token is not None:
        mask_start = _compute_mask_offset(start_seq_index, spec_info)
        if end_seq_index == spec_info.seq_lens_cpu.shape[0]:
            # Last child: take everything that remains.
            mask_end = custom_mask.shape[0]
        else:
            mask_end = _compute_mask_offset(end_seq_index, spec_info)
        # Keep the full mask when the computed window is empty/inverted.
        if mask_end > mask_start:
            custom_mask = custom_mask[mask_start:mask_end]

    retrive_index = _seq_slice(spec_info.retrive_index)
    retrive_next_token = _seq_slice(spec_info.retrive_next_token)
    retrive_next_sibling = _seq_slice(spec_info.retrive_next_sibling)
    retrive_cum_len = _seq_slice(spec_info.retrive_cum_len)
    seq_lens_cpu = _seq_slice(spec_info.seq_lens_cpu)
    seq_lens_sum = None if seq_lens_cpu is None else seq_lens_cpu.sum()

    return replace(
        spec_info,
        custom_mask=custom_mask,
        draft_token=draft_token,
        positions=positions,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        retrive_cum_len=retrive_cum_len,
        seq_lens_cpu=seq_lens_cpu,
        seq_lens_sum=seq_lens_sum,
    )


245
246
247
248
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
    token_num_per_seq: Optional[int],
) -> int:
    """Convert a sequence-level split index into a token-level split index."""
    if forward_mode == ForwardMode.EXTEND:
        assert extend_seq_lens is not None
        # Two-chunk mode cuts a sequence mid-way: the token boundary is half of
        # all tokens rather than a sequence boundary.
        if _is_two_chunk_split_enabled(extend_seq_lens):
            return sum(extend_seq_lens) // 2
        return sum(extend_seq_lens[:split_seq_index])
    if forward_mode.is_target_verify() or forward_mode.is_decode():
        assert token_num_per_seq is not None
        # Fixed tokens per sequence in these modes.
        return token_num_per_seq * split_seq_index
    if forward_mode.is_idle():
        assert split_seq_index == 0
        return 0
    raise NotImplementedError


266
267
268
def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode,
    cuda_graph_num_tokens: int,
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
    """Compute (split_seq_index, split_token_index) for a CUDA-graph replay.

    IDLE batches are split as if they were decode batches.
    """
    effective_mode = (
        ForwardMode.DECODE if forward_mode == ForwardMode.IDLE else forward_mode
    )
    # NOTE(review): token_num_per_seq is derived from the *original* mode, so an
    # IDLE batch yields 0 here while being split as decode — verify the IDLE
    # replay path never divides by this value.
    tokens_per_seq = get_token_num_per_seq(
        forward_mode=forward_mode, spec_info=spec_info
    )
    split_seq_index = compute_split_seq_index(
        forward_mode=effective_mode,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
        token_num_per_seq=tokens_per_seq,
    )
    split_token_index = compute_split_token_index(
        split_seq_index=split_seq_index,
        forward_mode=effective_mode,
        extend_seq_lens=None,
        token_num_per_seq=tokens_per_seq,
    )
    return split_seq_index, split_token_index


292
293
294
# -------------------------------- Preparation ---------------------------------------


295
296
class TboCudaGraphRunnerPlugin:
    """Hooks that make CUDA-graph capture/replay aware of two-batch overlap."""

    def __init__(self):
        # Length-2 buffer of per-child non-padded token counts. It is updated
        # in place (`[...] =`) so references taken at capture time stay valid.
        self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
        """Split `batch` into its two TBO children before a graph capture."""
        if not global_server_args_dict["enable_two_batch_overlap"]:
            return

        tokens_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        batch.tbo_split_seq_index = compute_split_seq_index(
            forward_mode=batch.forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
            token_num_per_seq=tokens_per_seq,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture
        # CUDA Graph for tbo=true
        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"

        self._tbo_children_num_token_non_padded[...] = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch)
        )
        TboForwardBatchPreparer.prepare_raw(
            batch,
            tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
        )

    def replay_prepare(
        self,
        forward_mode: ForwardMode,
        bs: int,
        num_token_non_padded: int,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Refresh the in-place token-count buffer before replaying a graph."""
        tokens_per_seq = get_token_num_per_seq(
            forward_mode=forward_mode, spec_info=spec_info
        )
        _, split_token_index = compute_split_indices_for_cuda_graph_replay(
            forward_mode=forward_mode,
            cuda_graph_num_tokens=bs * tokens_per_seq,
            spec_info=spec_info,
        )

        self._tbo_children_num_token_non_padded[...] = (
            TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded_raw(
                tbo_split_token_index=split_token_index,
                num_token_non_padded=num_token_non_padded,
            )
        )
348
349
350
351


class TboDPAttentionPreparer:
    """Coordinates, across DP attention ranks, whether TBO can run this step."""

    def prepare_all_gather(
        self,
        local_batch: ScheduleBatch,
        deepep_mode: DeepEPMode,
        enable_deepep_moe: bool,
        enable_two_batch_overlap: bool,
    ):
        """Compute this rank's split index and TBO eligibility.

        Returns (local_can_run_tbo, local_forward_mode_value) to be all-gathered
        across ranks; the decision is finalized in `compute_output`.
        """
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is None:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True
        else:
            mode = local_batch.forward_mode
            tokens_per_seq = get_token_num_per_seq(
                forward_mode=mode, spec_info=local_batch.spec_info
            )
            if mode.is_target_verify() or mode.is_decode():
                num_tokens = local_batch.batch_size() * tokens_per_seq
            else:
                num_tokens = local_batch.extend_num_tokens

            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=mode,
                num_tokens=num_tokens,
                extend_lens=local_batch.extend_lens,
                token_num_per_seq=tokens_per_seq,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
            # TBO is refused for plain extend batches when DeepEP MoE resolves
            # to low-latency dispatch.
            is_plain_extend = mode.is_extend() and not mode.is_target_verify()
            blocked_by_deepep = (
                is_plain_extend
                and enable_deepep_moe
                and resolved_deepep_mode == DeepEPMode.LOW_LATENCY
            )
            local_can_run_tbo = (
                self.local_tbo_split_seq_index is not None
            ) and not blocked_by_deepep

        local_forward_mode = self._compute_local_forward_mode(local_batch)
        return local_can_run_tbo, local_forward_mode

    def compute_output(self, partial_global_info):
        """Combine the all-gathered per-rank info into the final TBO decision.

        Returns (tbo_split_seq_index, global_forward_mode); both are None when
        TBO cannot run globally this step.
        """
        can_run_flags = partial_global_info[:, 0, 0].tolist()
        forward_modes = partial_global_info[:, 0, 1].tolist()

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )

        # TBO runs only if enabled, every rank can run it, and all non-idle
        # ranks agree on the forward mode.
        can_run_tbo = (
            self.enable_two_batch_overlap
            and min(can_run_flags)
            and forward_mode_agree
        )
        if not can_run_tbo:
            return None, None
        return self.local_tbo_split_seq_index, global_forward_mode

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        # Ranks without a batch report IDLE.
        if local_batch is None:
            return ForwardMode.IDLE.value
        return local_batch.forward_mode.value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        """Derive the shared forward mode, ignoring idle ranks.

        Returns (mode, agree); mode is None when non-idle ranks disagree, and
        (IDLE, False) when every rank is idle.
        """
        non_idle = [m for m in forward_modes if m != ForwardMode.IDLE.value]
        if not non_idle:
            return ForwardMode.IDLE, False

        agree = TboDPAttentionPreparer._is_all_same(non_idle)
        mode = ForwardMode(non_idle[0]) if agree else None
        return mode, agree

    @staticmethod
    def _is_all_same(x):
        return all(item == x[0] for item in x)


class TboForwardBatchPreparer:
    """Splits a ForwardBatch into the two child micro-batches used by TBO.

    The parent batch is cut at `batch.tbo_split_seq_index` (sequence level) and
    the derived token index; the resulting children are attached to
    `batch.tbo_children`.
    """

    @classmethod
    def prepare(cls, batch: ForwardBatch, is_draft_worker: bool = False):
        """Split `batch` in place when TBO applies.

        No-op when the batch has no split index or when running in a draft
        worker.
        """
        if batch.tbo_split_seq_index is None or is_draft_worker:
            return

        tbo_children_num_token_non_padded = (
            cls.compute_tbo_children_num_token_non_padded(batch)
        )
        cls.prepare_raw(
            batch, tbo_children_num_token_non_padded=tbo_children_num_token_non_padded
        )

    @classmethod
    def prepare_raw(
        cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
    ):
        """Build child_a/child_b and store them as `batch.tbo_children`.

        `tbo_children_num_token_non_padded` is a length-2 tensor with each
        child's non-padded token count.
        """
        # Local import — presumably avoids a circular import; confirm before moving.
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        tbo_split_token_index = cls._compute_split_token_index(batch)

        # Two-chunk mode cuts one sequence in the middle, so the boundary
        # sequence appears in both children (see end_seq_index below).
        is_enable_two_chunk = (
            batch.forward_mode == ForwardMode.EXTEND
            and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
        )

        if _tbo_debug:
            logger.info(
                f"TboForwardBatchPreparer.prepare "
                f"is_enable_two_chunk={is_enable_two_chunk} "
                f"tbo_split_seq_index={batch.tbo_split_seq_index} "
                f"tbo_split_token_index={tbo_split_token_index} "
                f"extend_seq_lens={batch.extend_seq_lens_cpu} "
                f"bs={batch.batch_size} "
                f"forward_mode={batch.forward_mode}"
            )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

        [out_num_token_non_padded_a, out_num_token_non_padded_b] = (
            tbo_children_num_token_non_padded
        )

        child_a = cls.filter_batch(
            batch,
            start_token_index=0,
            end_token_index=tbo_split_token_index,
            start_seq_index=0,
            # In two-chunk mode child_a also owns the first half of the
            # boundary sequence.
            end_seq_index=(
                batch.tbo_split_seq_index + 1
                if is_enable_two_chunk
                else batch.tbo_split_seq_index
            ),
            output_attn_backend=attn_backend_child_a,
            out_num_token_non_padded=out_num_token_non_padded_a,
        )
        child_b = cls.filter_batch(
            batch,
            start_token_index=tbo_split_token_index,
            end_token_index=batch.input_ids.shape[0],
            start_seq_index=batch.tbo_split_seq_index,
            end_seq_index=batch.batch_size,
            output_attn_backend=attn_backend_child_b,
            out_num_token_non_padded=out_num_token_non_padded_b,
        )

        if is_enable_two_chunk:
            # Fix up length-dependent fields of the shared boundary sequence.
            cls.derive_fields_related_to_seq_len_for_two_chunk(
                batch,
                child_a=child_a,
                child_b=child_b,
                tbo_split_seq_index=batch.tbo_split_seq_index,
            )

        assert batch.tbo_children is None
        batch.tbo_children = [child_a, child_b]

    @classmethod
    def derive_fields_related_to_seq_len_for_two_chunk(
        cls,
        batch: ForwardBatch,
        *,
        child_a: ForwardBatch,
        child_b: ForwardBatch,
        tbo_split_seq_index: int,
    ):
        """Adjust seq-length fields when one sequence is split across children.

        Child A keeps the first `left_last_seq_token_num` tokens of the boundary
        sequence; child B treats them as additional prefix.
        """
        extend_seq_lens_cpu = batch.extend_seq_lens_cpu
        overall_seq_lens_sum = sum(extend_seq_lens_cpu)
        half_seq_lens_sum = overall_seq_lens_sum // 2
        # Tokens of the boundary sequence assigned to child A...
        left_last_seq_token_num = half_seq_lens_sum - sum(
            extend_seq_lens_cpu[:tbo_split_seq_index]
        )
        # ...and the remainder, which goes to child B.
        right_first_seq_token_num = (
            extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
        )

        # making deepcopy to be extra safe
        child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
        child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
        child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
        child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
        for child in [child_a, child_b]:
            _update_device_and_sum_field_from_cpu_field(
                batch=child,
                cpu_field="extend_seq_lens_cpu",
                device_field="extend_seq_lens",
                sum_field="extend_num_tokens",
            )

        assert (
            child_a.extend_num_tokens == half_seq_lens_sum
        ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

        # Child A's boundary sequence is shorter now; recompute its seq len.
        child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
        child_a.seq_lens_cpu[-1] = (
            child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
        )
        _update_device_and_sum_field_from_cpu_field(
            batch=child_a,
            cpu_field="seq_lens_cpu",
            device_field="seq_lens",
            sum_field="seq_lens_sum",
        )

        # Child B sees child A's share of the boundary sequence as prefix.
        child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
        child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
        _update_device_and_sum_field_from_cpu_field(
            batch=child_b,
            cpu_field="extend_prefix_lens_cpu",
            device_field="extend_prefix_lens",
            sum_field=None,
        )
        _, child_b.extend_start_loc = compute_position(
            global_server_args_dict["attention_backend"],
            child_b.extend_prefix_lens,
            child_b.extend_seq_lens,
            child_b.extend_num_tokens,
        )

    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
        out_num_token_non_padded: torch.Tensor,
    ):
        """Build a child ForwardBatch from a token range and a sequence range.

        Token-indexed fields are sliced by [start_token_index, end_token_index),
        sequence-indexed fields by [start_seq_index, end_seq_index); shared
        fields (pools, modes, flags) are carried over, and DP/global fields are
        reset to None for the child. Raises if any populated field of
        ForwardBatch is not explicitly handled.
        """
        assert (
            end_token_index >= start_token_index
        ), f"{end_token_index=}, {start_token_index=}, batch={batch}"

        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        output_dict = dict()

        # Token-indexed fields.
        for key in [
            "input_ids",
            "positions",
            "out_cache_loc",
        ]:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[start_token_index:end_token_index]

        # Sequence-indexed fields.
        for key in [
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_ids",
        ]:
            old_value = getattr(batch, key)
            if old_value is None:
                continue
            # Target-verify children do not keep per-sequence extend metadata.
            elif batch.forward_mode.is_target_verify() and (
                key == "extend_seq_lens"
                or key == "extend_prefix_lens"
                or key == "extend_start_loc"
                or key == "extend_prefix_lens_cpu"
                or key == "extend_seq_lens_cpu"
                or key == "extend_logprob_start_lens_cpu"
            ):
                output_dict[key] = None
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[start_seq_index:end_seq_index]

        # Speculative-decoding info gets its own slicing logic.
        spec_info = getattr(batch, "spec_info")
        output_spec_info = split_spec_info(
            spec_info=spec_info,
            start_token_index=start_token_index,
            end_token_index=end_token_index,
            start_seq_index=start_seq_index,
            end_seq_index=end_seq_index,
        )
        output_dict["spec_info"] = output_spec_info
        # Fields shared by both children unchanged.
        for key in [
            "forward_mode",
            "is_extend_in_batch",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
            "split_index",  # for split prefill
            "orig_seq_lens",  # only used by qwen-1m, thus not care
        ]:
            output_dict[key] = getattr(batch, key)

        # Sanity check: parent's extend_num_tokens matches its input length.
        if not batch.forward_mode.is_target_verify():
            assert (
                _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
                == batch.extend_num_tokens
            ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO improve, e.g. unify w/ `init_raw`
        if (
            global_server_args_dict["moe_dense_tp_size"] == 1
            and batch.gathered_buffer is not None
        ):
            # Fresh, child-sized gather buffer instead of sharing the parent's.
            sum_len = end_token_index - start_token_index
            gathered_buffer = torch.zeros(
                (sum_len, batch.gathered_buffer.shape[1]),
                dtype=batch.gathered_buffer.dtype,
                device=batch.gathered_buffer.device,
            )
        else:
            gathered_buffer = None

        output_dict.update(
            dict(
                batch_size=end_seq_index - start_seq_index,
                seq_lens_sum=(
                    output_dict["seq_lens_cpu"].sum()
                    if "seq_lens_cpu" in output_dict
                    else None
                ),
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                num_token_non_padded=out_num_token_non_padded,
                # Children are never split again.
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
                global_num_tokens_gpu=None,
                global_num_tokens_cpu=None,
                dp_padding_mode=None,
                gathered_buffer=gathered_buffer,
                global_num_tokens_for_logprob_gpu=None,
                global_num_tokens_for_logprob_cpu=None,
                sampling_info=None,
                # For logits and logprobs post processing, thus we do not care
                temp_scaled_logprobs=False,
                temperature=None,
                top_p_normalized_logprobs=False,
                top_p=None,
                mm_inputs=None,
                top_logprobs_nums=None,
                token_ids_logprobs=None,
                next_token_logits_buffer=None,
            )
        )

        # Fail loudly on any populated ForwardBatch field we did not handle.
        errors = []
        for field in dataclasses.fields(ForwardBatch):
            if getattr(batch, field.name) is not None and field.name not in output_dict:
                errors.append(
                    f"Field {field.name} has value, but is not yet supported (value={getattr(batch, field.name)} batch={batch})"
                )
        if len(errors) > 0:
            raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)

    @classmethod
    def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
        """Per-child non-padded token counts, derived from the batch's split."""
        return cls.compute_tbo_children_num_token_non_padded_raw(
            tbo_split_token_index=cls._compute_split_token_index(batch),
            num_token_non_padded=len(batch.input_ids),
        )

    @classmethod
    def compute_tbo_children_num_token_non_padded_raw(
        cls, tbo_split_token_index: int, num_token_non_padded: int
    ):
        """Split `num_token_non_padded` across the two children at the token cut.

        Returns an int32 tensor [tokens_child_a, tokens_child_b] on the
        configured device.
        """
        # TODO we may make padding on both sub-batches to make it slightly more balanced
        value_a = min(tbo_split_token_index, num_token_non_padded)
        value_b = max(0, num_token_non_padded - tbo_split_token_index)
        return torch.tensor([value_a, value_b], dtype=torch.int32).to(
            device=global_server_args_dict["device"], non_blocking=True
        )

    @classmethod
    def _compute_split_token_index(cls, batch: ForwardBatch):
        # Token-level cut corresponding to batch.tbo_split_seq_index.
        token_num_per_seq = get_token_num_per_seq(
            forward_mode=batch.forward_mode, spec_info=batch.spec_info
        )
        return compute_split_token_index(
            split_seq_index=batch.tbo_split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
            token_num_per_seq=token_num_per_seq,
        )

764
765

def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
766
767
768
769
770
    if (
        forward_mode.is_decode()
        or forward_mode.is_idle()
        or forward_mode.is_target_verify()
    ):
771
        return None
772
773
    elif forward_mode.is_extend():
        return input_ids.shape[0]
774
775
776
777
778
779
780
781
782
783
784
785
    raise NotImplementedError


# -------------------------------- Execution ---------------------------------------


def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    input_data_scatter_mode: ScatterMode,
    residual: Optional[torch.Tensor],
    zero_allocator: Optional[BumpAllocator] = None,
):
    """Run the given layers, overlapping two micro-batches when `enable_tbo`.

    Returns (hidden_states, residual) either way.
    """
    common_inputs = dict(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        residual=residual,
        zero_allocator=zero_allocator,
    )
    layer_input_mode = layers[0].layer_scatter_modes.layer_input_mode
    strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )

    if not enable_tbo:
        return _model_forward_non_tbo(common_inputs, strategy)
    return _model_forward_tbo(
        inputs=common_inputs,
        operations_strategy=strategy,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_mode,
    )


def _model_forward_tbo(
    inputs,
    operations_strategy: OperationsStrategy,
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
):
    """Split the inputs into two sub-batches and execute them with stage overlap."""
    split_inputs = _model_forward_tbo_split_inputs(
        **inputs,
        input_data_scatter_mode=input_data_scatter_mode,
        layer_input_scatter_mode=layer_input_scatter_mode,
    )
    # Drop the un-split reference; only the per-sub-batch dicts keep the tensors.
    del inputs

    # Apply the strategy's DeepGEMM SM budget while the overlapped ops run.
    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
        operations_strategy.deep_gemm_num_sms
    ):
        ops = operations_strategy.operations
        sub_outputs = execute_overlapped_operations(
            inputs_arr=split_inputs,
            operations_arr=[ops, ops],
            # The second sub-batch is offset by tbo_delta_stages to interleave.
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*sub_outputs)


def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    """Execute the operations sequentially and return (hidden_states, residual)."""
    result = execute_operations(inputs, operations_strategy.operations)
    return result["hidden_states"], result["residual"]


def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
    input_data_scatter_mode: ScatterMode,
    layer_input_scatter_mode: ScatterMode,
) -> List[Dict]:
    """Split full-batch inputs into the two per-sub-batch input dicts.

    Tensors are first brought to TP_ATTN_FULL (the mode the splitter slices in),
    split per child batch, then re-scattered to the mode the first layer expects.
    """
    splitter_mode = ScatterMode.TP_ATTN_FULL
    comm_context = CommunicateContext.init_new()

    # Full batch: input mode -> splitter mode.
    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
        hidden_states_input_mode=input_data_scatter_mode,
        residual_input_mode=input_data_scatter_mode,
        output_mode=splitter_mode,
        hidden_states=hidden_states,
        residual=residual,
        forward_batch=forward_batch,
        context=comm_context,
    )

    raw_inputs_arr = _model_forward_tbo_split_inputs_raw(
        hidden_states=hidden_states,
        residual=residual,
        positions=positions,
        forward_batch=forward_batch,
        zero_allocator=zero_allocator,
    )

    # Per sub-batch: splitter mode -> layer input mode.
    transformed = []
    for sub_inputs in raw_inputs_arr:
        passthrough = {
            key: value
            for key, value in sub_inputs.items()
            if key not in ("hidden_states", "residual", "forward_batch")
        }
        sub_batch = sub_inputs["forward_batch"]
        sub_hidden, sub_residual = CommunicateSummableTensorPairFn.execute(
            hidden_states_input_mode=splitter_mode,
            residual_input_mode=splitter_mode,
            output_mode=layer_input_scatter_mode,
            hidden_states=sub_inputs["hidden_states"],
            residual=sub_inputs["residual"],
            forward_batch=sub_batch,
            context=comm_context,
        )
        transformed.append(
            dict(
                hidden_states=sub_hidden,
                residual=sub_residual,
                forward_batch=sub_batch,
                **passthrough,
            )
        )
    return transformed


def _model_forward_tbo_split_inputs_raw(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: Optional[BumpAllocator],
) -> List[Dict]:
    """Build one input dict per TBO child batch by slicing the parent tensors."""
    sub_inputs_list = []
    for sub_index, child_batch in enumerate(forward_batch.tbo_children):
        sub_inputs = _model_forward_filter_inputs(
            hidden_states=hidden_states,
            residual=residual,
            positions=positions,
            output_forward_batch=child_batch,
            tbo_subbatch_index=sub_index,
        )
        # Only forward the allocator when the caller supplied one.
        if zero_allocator is not None:
            sub_inputs["zero_allocator"] = zero_allocator
        sub_inputs_list.append(sub_inputs)
    return sub_inputs_list


def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)
    return dict(
        hidden_states=hidden_states[token_slice],
        residual=None if residual is None else residual[token_slice],
        positions=positions[token_slice],
        forward_batch=output_forward_batch,
        tbo_subbatch_index=tbo_subbatch_index,
    )


def _model_forward_tbo_merge_outputs(output_a, output_b):
    def _handle_key(name):
        value_a = output_a[name]
        value_b = output_b[name]
        assert (value_a is None) == (value_b is None)
        if value_a is None:
            return None
        return torch.concat([value_a, value_b], dim=0)

    return _handle_key("hidden_states"), _handle_key("residual")


# -------------------------------- Utilities and wrappers ---------------------------------------


class MaybeTboDeepEPDispatcher:
    """DeepEPDispatcher wrapper that keeps one inner dispatcher per TBO sub-batch.

    With two-batch overlap enabled there are two inner dispatchers (one per
    sub-batch); otherwise a single one. Each public method forwards its kwargs
    to the inner dispatcher selected by ``tbo_subbatch_index`` (default: the
    first one).
    """

    def __init__(self, **kwargs):
        tbo_enabled = global_server_args_dict["enable_two_batch_overlap"]
        self._inners = [
            DeepEPDispatcher(**kwargs) for _ in range(2 if tbo_enabled else 1)
        ]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        # A missing sub-batch index falls back to dispatcher 0.
        index = 0 if tbo_subbatch_index is None else tbo_subbatch_index
        return getattr(self._inners[index], name)(**kwargs)

    def dispatch(self, **kwargs) -> DispatchOutput:
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs) -> torch.Tensor:
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        return self._execute("combine_b", **kwargs)