# single_batch_overlap.py
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch

from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var

if TYPE_CHECKING:
    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE


class SboFlags:
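    """Feature flags for single-batch overlap (SBO).

    Each flag gates one overlap pattern between the MoE combine step and other
    work (the down GEMM, the shared experts), based on the server's SBO setting
    and the active MoE runner backend.
    """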
    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        return (
            is_sbo_enabled()
            # currently only cutedsl backend supports it
            and get_moe_runner_backend().is_flashinfer_cutedsl()
        )

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        return is_sbo_enabled()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
        return cls.enable_combine_shared_two_stream_overlap()


@dataclass
class CombineOverlapArgs:
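    """Arguments passed to the combine step so it can overlap with other work.

    `stream` and `wait_event` describe where the combine runs and what it waits
    on; `num_sms` is the SM budget for the combine send; `signal` and
    `threshold` are only used when overlapping with the down GEMM.
    """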
    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
    overlap: bool
    stream: torch.cuda.Stream
    wait_event: torch.cuda.Event
    num_sms: int
    signal: Optional[torch.Tensor] = None
    threshold: int = 0


@dataclass
class DownGemmOverlapArgs:
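    """Arguments passed to the down GEMM when it is overlapped with the combine:
    the SM budget left for compute, the signal tensor it shares with the combine
    kernel, and a start event used to synchronize with the combine stream.
    """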
    num_sms: int
    signal: torch.Tensor
    start_event: torch.cuda.Event


def execute_sbo(
    forward_shared_experts: Callable[[], Any],
    experts: "DeepEPMoE",
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch: ForwardBatch,
    alt_stream: Optional[torch.cuda.Stream] = None,
):
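    """Run a DeepEPMoE layer with single-batch overlap (SBO).

    The flow is dispatch -> expert GEMMs (`moe_impl`) -> combine, with the
    combine optionally overlapped with the down GEMM and/or the shared-expert
    forward (`forward_shared_experts`) on `alt_stream`, depending on `SboFlags`.
    Returns the combined hidden states.
    """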
    dispatch_output = experts.dispatch(
        hidden_states, topk_idx, topk_weights, forward_batch
    )

    combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
        _compute_overlap_args(dispatch_output, alt_stream)
    )

    hidden_states = experts.moe_impl(
        dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
    )
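    # When the combine is not overlapped with the down GEMM, an event is
    # recorded after the expert GEMMs so the combine can wait on it (see
    # `record_event_after_down` in `_compute_overlap_args`).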
    if (e := meta_overlap_args.get("record_event_after_down")) is not None:
        e.record()

    if SboFlags.enable_combine_shared_two_stream_overlap():
        # TODO reduce SMs for non-DeepGEMM backends
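        # Cap DeepGEMM's SM usage so the shared-expert GEMMs leave
        # `communicate_num_sms` SMs free for the concurrent combine send.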
        with deep_gemm_wrapper.configure_deep_gemm_num_sms(
            meta_overlap_args["compute_num_sms"]
        ):
            forward_shared_experts()

    hidden_states = experts.combine(
        hidden_states,
        dispatch_output.topk_idx,
        dispatch_output.topk_weights,
        forward_batch,
        overlap_args=combine_overlap_args,
    )

    return hidden_states
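
# Illustrative usage sketch (an assumption, not part of the original module):
# how a caller could wire `execute_sbo` into an MoE block when shared experts
# are fused inside SBO.  `self.experts`, `self.shared_experts`, and
# `self.alt_stream` are hypothetical attributes of the calling layer.
#
#     if SboFlags.fuse_shared_experts_inside_sbo():
#         hidden_states = execute_sbo(
#             forward_shared_experts=lambda: self.shared_experts(hidden_states),
#             experts=self.experts,
#             hidden_states=hidden_states,
#             topk_idx=topk_idx,
#             topk_weights=topk_weights,
#             forward_batch=forward_batch,
#             alt_stream=self.alt_stream,
#         )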


def _compute_overlap_args(dispatch_output, alt_stream):
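    """Build the overlap argument bundles for the enabled SBO modes.

    Returns `(combine_overlap_args, down_gemm_overlap_args, meta_overlap_args)`,
    or `(None, None, {})` when no overlap mode is enabled. The device's SMs are
    split between the combine send (SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS,
    default 32) and the compute kernels that run concurrently with it.
    """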
    if not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    hidden_states = dispatch_output.hidden_states_fp8
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count
    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
    compute_num_sms = total_num_sms - communicate_num_sms

    assert alt_stream is not None
    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = dict(
        compute_num_sms=compute_num_sms,
    )
    down_gemm_overlap_args = None

    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # TODO use zero_allocator to remove this `torch.zeros` call
        # NOTE: our v2 currently uses uint32, not int32
        combine_signal = torch.zeros(
            num_local_experts, dtype=torch.uint32, device=hidden_states.device
        )
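        # `combine_signal` is written by the down GEMM (via
        # `DownGemmOverlapArgs.signal`) and watched by the combine kernel, which
        # compares it against `threshold` so it only consumes data the down GEMM
        # has already produced.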

        down_gemm_overlap_args = DownGemmOverlapArgs(
            signal=combine_signal,
            start_event=combine_wait_event,
            num_sms=compute_num_sms,
        )
        combine_overlap_args.overlap = True
        combine_overlap_args.signal = combine_signal
        combine_overlap_args.threshold = compute_num_sms
    else:
        meta_overlap_args |= dict(
            record_event_after_down=combine_wait_event,
        )

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args