"vscode:/vscode.git/clone" did not exist on "dafe46710b5f4a93bfdceb84c7201d1c83423394"
Unverified Commit a071dc40 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny add stage assertions to DeepEPDispatcher to avoid misuse (#6467)

parent a40aecc5
import logging import logging
from dataclasses import dataclass
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_distribution import ( from sglang.srt.managers.expert_distribution import (
...@@ -18,7 +19,7 @@ try: ...@@ -18,7 +19,7 @@ try:
except ImportError: except ImportError:
use_deepep = False use_deepep = False
from enum import IntEnum, auto from enum import Enum, IntEnum, auto
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -627,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
) )
@dataclass
class _Stage(Enum):
INITIAL = auto()
AFTER_DISPATCH_A = auto()
AFTER_DISPATCH_B = auto()
AFTER_COMBINE_A = auto()
class DeepEPDispatcher: class DeepEPDispatcher:
def __init__( def __init__(
self, self,
...@@ -665,6 +674,8 @@ class DeepEPDispatcher: ...@@ -665,6 +674,8 @@ class DeepEPDispatcher:
**common_kwargs, **common_kwargs,
) )
self._stage = _Stage.INITIAL
def dispatch(self, *args, **kwargs) -> Tuple: def dispatch(self, *args, **kwargs) -> Tuple:
self.dispatch_a(*args, **kwargs) self.dispatch_a(*args, **kwargs)
ret = self.dispatch_b() ret = self.dispatch_b()
...@@ -677,6 +688,7 @@ class DeepEPDispatcher: ...@@ -677,6 +688,7 @@ class DeepEPDispatcher:
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
forward_mode: ForwardMode = None, forward_mode: ForwardMode = None,
): ):
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
inner_state = self._get_impl(forward_mode).dispatch_a( inner_state = self._get_impl(forward_mode).dispatch_a(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_idx=topk_idx,
...@@ -685,6 +697,7 @@ class DeepEPDispatcher: ...@@ -685,6 +697,7 @@ class DeepEPDispatcher:
self._dispatch_intermediate_state = forward_mode, inner_state self._dispatch_intermediate_state = forward_mode, inner_state
def dispatch_b(self): def dispatch_b(self):
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
forward_mode, inner_state = self._dispatch_intermediate_state forward_mode, inner_state = self._dispatch_intermediate_state
del self._dispatch_intermediate_state del self._dispatch_intermediate_state
return self._get_impl(forward_mode).dispatch_b(*inner_state) return self._get_impl(forward_mode).dispatch_b(*inner_state)
...@@ -701,6 +714,7 @@ class DeepEPDispatcher: ...@@ -701,6 +714,7 @@ class DeepEPDispatcher:
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
forward_mode: ForwardMode, forward_mode: ForwardMode,
): ):
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
inner_state = self._get_impl(forward_mode).combine_a( inner_state = self._get_impl(forward_mode).combine_a(
hidden_states=hidden_states, hidden_states=hidden_states,
topk_idx=topk_idx, topk_idx=topk_idx,
...@@ -709,6 +723,7 @@ class DeepEPDispatcher: ...@@ -709,6 +723,7 @@ class DeepEPDispatcher:
self._combine_intermediate_state = forward_mode, inner_state self._combine_intermediate_state = forward_mode, inner_state
def combine_b(self): def combine_b(self):
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
forward_mode, inner_state = self._combine_intermediate_state forward_mode, inner_state = self._combine_intermediate_state
del self._combine_intermediate_state del self._combine_intermediate_state
return self._get_impl(forward_mode).combine_b(*inner_state) return self._get_impl(forward_mode).combine_b(*inner_state)
...@@ -721,3 +736,7 @@ class DeepEPDispatcher: ...@@ -721,3 +736,7 @@ class DeepEPDispatcher:
return self._low_latency_dispatcher return self._low_latency_dispatcher
else: else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
def _update_stage(self, old_stage, new_stage):
assert self._stage == old_stage
self._stage = new_stage
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment