# single_batch_overlap.py
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch

from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.utils import get_int_env_var

if TYPE_CHECKING:
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE


class SboFlags:
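    """Feature flags for single-batch overlap (SBO): which parts of the MoE
    pipeline (combine, down GEMM, shared experts) are allowed to overlap."""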
    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        return (
            is_sbo_enabled()
            # currently only the cutedsl backend supports it
            and get_moe_runner_backend().is_flashinfer_cutedsl()
        )

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        return is_sbo_enabled()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
        return cls.enable_combine_shared_two_stream_overlap()


@dataclass
class CombineOverlapArgs:
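    """Configuration for overlapping the combine with other work."""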
    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
    overlap: bool
    stream: torch.cuda.Stream
    wait_event: torch.cuda.Event
    num_sms: int
    signal: Optional[torch.Tensor] = None
    threshold: int = 0


@dataclass
class DownGemmOverlapArgs:
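    """Configuration handed to the down GEMM so it can coordinate, via the
    shared signal tensor and start event, with the overlapped combine."""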
    num_sms: int
    signal: torch.Tensor
    start_event: torch.cuda.Event


def execute_sbo(
    forward_shared_experts: Callable[[], Any],
    experts: FusedMoE,
    hidden_states: torch.Tensor,
    topk_output: TopKOutput,
    alt_stream: Optional[torch.cuda.Stream] = None,
    disable_sbo: bool = False,
):
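    """Execute MoE dispatch -> expert GEMMs -> combine with single-batch overlap (SBO).

    Depending on SboFlags, the down GEMM can overlap with the combine on a side
    stream, and the shared experts can run on a reduced deep_gemm SM budget so
    that SMs stay free for the combine communication.
    """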

    dispatch_output = experts.dispatcher.dispatch(
        hidden_states=hidden_states, topk_output=topk_output
    )

    combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
        _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
    )

    combine_input = experts.run_moe_core(
        dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
    )
    if (e := meta_overlap_args.get("record_event_after_down")) is not None:
        e.record()

    if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
        # TODO reduce sm for non-deepgemm
        with deep_gemm_wrapper.configure_deep_gemm_num_sms(
            meta_overlap_args["compute_num_sms"]
        ):
            forward_shared_experts()

    hidden_states = experts.dispatcher.combine(combine_input=combine_input)

    return hidden_states


def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
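    """Build (combine_overlap_args, down_gemm_overlap_args, meta_overlap_args).

    Splits the device's SMs into a communication budget for the combine send
    (SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS, default 32) and a compute budget for
    the GEMMs / shared experts, and wires up the signal tensor and events needed
    by the enabled overlap modes. Returns (None, None, {}) when SBO is disabled.
    """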
    if disable_sbo or not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    hidden_states = dispatch_output.hidden_states

    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count
    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = dict(
        compute_num_sms=compute_num_sms,
    )
    down_gemm_overlap_args = None

    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # TODO use zero_allocator to remove this `torch.zeros` call
        # NOTE: our v2 currently uses uint32, not int32
        combine_signal = torch.zeros(
            num_local_experts, dtype=torch.uint32, device=hidden_states.device
        )

        down_gemm_overlap_args = DownGemmOverlapArgs(
            signal=combine_signal,
            start_event=combine_wait_event,
            num_sms=compute_num_sms,
        )
        combine_overlap_args.overlap = True
        combine_overlap_args.signal = combine_signal
        combine_overlap_args.threshold = compute_num_sms
    else:
        meta_overlap_args |= dict(
            record_event_after_down=combine_wait_event,
        )

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
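

# Illustrative sketch (not part of this module's API): roughly how an MoE layer
# could call `execute_sbo` from its forward pass. The attribute names used here
# (`self.experts`, `self.shared_experts`, `self.alt_stream`) are hypothetical
# and only indicate the expected argument types.
#
#     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
#         shared_output = []
#
#         def _forward_shared():
#             # execute_sbo ignores the callable's return value, so capture it
#             shared_output.append(self.shared_experts(hidden_states))
#
#         routed_output = execute_sbo(
#             forward_shared_experts=_forward_shared,
#             experts=self.experts,        # FusedMoE with an EP dispatcher
#             hidden_states=hidden_states,
#             topk_output=topk_output,
#             alt_stream=self.alt_stream,  # side stream used for the overlap
#         )
#         return routed_output, shared_output[0]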