Unverified Commit a071dc40 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny add stage assertions to DeepEPDispatcher to avoid misuse (#6467)

parent a40aecc5
import logging
from dataclasses import dataclass
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_distribution import (
......@@ -18,7 +19,7 @@ try:
except ImportError:
use_deepep = False
from enum import IntEnum, auto
from enum import Enum, IntEnum, auto
from typing import Optional, Tuple, Union
import torch
......@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
)
class _Stage(Enum):
    """Phases of one dispatch/combine cycle tracked by DeepEPDispatcher.

    The dispatcher starts in INITIAL; dispatch_a, dispatch_b and combine_a
    advance it through the members in the order listed, and combine_b
    returns it to INITIAL.
    """

    # NOTE: this enum must not be decorated with @dataclass. With no annotated
    # fields, the dataclass-generated __eq__ compares empty tuples, so every
    # member would compare equal to every other member (and become
    # unhashable), silently defeating the stage checks.
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()
class DeepEPDispatcher:
def __init__(
self,
......@@ -665,6 +674,8 @@ class DeepEPDispatcher:
**common_kwargs,
)
self._stage = _Stage.INITIAL
def dispatch(self, *args, **kwargs) -> Tuple:
self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b()
......@@ -677,6 +688,7 @@ class DeepEPDispatcher:
topk_weights: torch.Tensor,
forward_mode: ForwardMode = None,
):
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
inner_state = self._get_impl(forward_mode).dispatch_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
......@@ -685,6 +697,7 @@ class DeepEPDispatcher:
self._dispatch_intermediate_state = forward_mode, inner_state
def dispatch_b(self):
    """Finish the dispatch started by ``dispatch_a`` and return its result.

    Advances the stage from AFTER_DISPATCH_A to AFTER_DISPATCH_B, consumes
    the ``(forward_mode, inner_state)`` pair stored by ``dispatch_a`` and
    forwards the unpacked inner state to the selected implementation.
    """
    self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
    mode, pending = self._dispatch_intermediate_state
    # Drop the stored state so it cannot be consumed a second time.
    del self._dispatch_intermediate_state
    impl = self._get_impl(mode)
    return impl.dispatch_b(*pending)
......@@ -701,6 +714,7 @@ class DeepEPDispatcher:
topk_weights: torch.Tensor,
forward_mode: ForwardMode,
):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_mode).combine_a(
hidden_states=hidden_states,
topk_idx=topk_idx,
......@@ -709,6 +723,7 @@ class DeepEPDispatcher:
self._combine_intermediate_state = forward_mode, inner_state
def combine_b(self):
    """Finish the combine started by ``combine_a`` and return its result.

    Advances the stage from AFTER_COMBINE_A back to INITIAL, consumes the
    ``(forward_mode, inner_state)`` pair stored by ``combine_a`` and forwards
    the unpacked inner state to the selected implementation.
    """
    self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
    mode, pending = self._combine_intermediate_state
    # Drop the stored state so it cannot be consumed a second time.
    del self._combine_intermediate_state
    impl = self._get_impl(mode)
    return impl.combine_b(*pending)
......@@ -721,3 +736,7 @@ class DeepEPDispatcher:
return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
def _update_stage(self, old_stage, new_stage):
    """Check that the dispatcher is in ``old_stage`` and move to ``new_stage``.

    Args:
        old_stage: The stage the dispatcher is required to be in right now.
        new_stage: The stage recorded after the check passes.

    Raises:
        AssertionError: If the current stage is not ``old_stage``.
    """
    # Enum members are singletons, so compare by identity. This keeps the
    # check meaningful even if the stage type defines a permissive __eq__.
    assert (
        self._stage is old_stage
    ), f"Unexpected stage: expected {old_stage}, but current stage is {self._stage}"
    self._stage = new_stage
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment