"src/vscode:/vscode.git/clone" did not exist on "c4a3b09a36fb22b949dc7d56f447206d5fd3b0d5"
Unverified commit a5a03209, authored by Cheng Wan, committed by GitHub

Fix circular import (#10107)
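The fix applies the two standard remedies for Python import cycles: imports needed only for type annotations move under `if TYPE_CHECKING:`, which type checkers evaluate but the interpreter skips, and imports needed at runtime move into the function bodies that use them, so the lookup happens at call time, after every module has finished loading. A minimal sketch of both patterns (the `dispatcher` module and its names are hypothetical, not the actual sglang files):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only import: visible to mypy/pyright, skipped at runtime,
    # so it cannot participate in an import cycle.
    from dispatcher import DispatchOutput


def run_moe(dispatch_output: DispatchOutput) -> None:
    # Deferred runtime import: resolved on first call, by which point
    # both modules have finished initializing.
    from dispatcher import combine

    combine(dispatch_output)
```

Repeated calls only pay a `sys.modules` cache lookup for the function-local import, which is negligible next to the MoE kernels these functions launch.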

parent 21af5c04
@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
 
 import torch
 
-from sglang.srt.layers.moe.token_dispatcher import (
-    CombineInput,
-    CombineInputFormat,
-    DispatchOutput,
-    DispatchOutputFormat,
-)
 from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
 
 if TYPE_CHECKING:
@@ -20,6 +14,12 @@
         TritonRunnerInput,
         TritonRunnerOutput,
     )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
@@ -143,17 +143,12 @@ class PermuteMethodPool:
         :param runner_backend_name: The MoeRunnerBackend name.
         :param permute_func: The permute function to register.
         """
+        # TODO: check if registration is valid
         key = (dispatch_output_name, runner_backend_name)
         if key in cls._pre_permute_methods:
             raise ValueError(
                 f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
             )
-        assert DispatchOutputFormat(
-            dispatch_output_name
-        ), f"Invalid dispatch output name: {dispatch_output_name}"
-        assert MoeRunnerBackend(
-            runner_backend_name
-        ), f"Invalid runner backend name: {runner_backend_name}"
         cls._pre_permute_methods[key] = permute_func
 
     @classmethod
@@ -170,17 +165,12 @@
         :param combine_input_name: The CombineInputFormat name.
         :param permute_func: The permute function to register.
         """
+        # TODO: check if registration is valid
         key = (runner_backend_name, combine_input_name)
         if key in cls._post_permute_methods:
             raise ValueError(
                 f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
             )
-        assert MoeRunnerBackend(
-            runner_backend_name
-        ), f"Invalid runner backend name: {runner_backend_name}"
-        assert CombineInputFormat(
-            combine_input_name
-        ), f"Invalid combine input name: {combine_input_name}"
         cls._post_permute_methods[key] = permute_func
 
     @classmethod
@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import (
     PermuteMethodPool,
 )
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
-from sglang.srt.layers.moe.token_dispatcher.base import (
-    CombineInput,
-    CombineInputFormat,
-    DispatchOutput,
-)
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
     from sglang.srt.layers.moe.utils import MoeRunnerBackend
 
 logger = logging.getLogger(__name__)
@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import (
     register_post_permute,
     register_pre_permute,
 )
-from sglang.srt.layers.moe.token_dispatcher import (
-    StandardCombineInput,
-    StandardDispatchOutput,
-)
 from sglang.srt.layers.moe.utils import MoeRunnerBackend
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -325,6 +328,7 @@ def fused_experts_none_to_triton(
     runner_config: MoeRunnerConfig,
 ) -> StandardCombineInput:
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
 
     output = fused_experts(
         hidden_states=dispatch_output.hidden_states,
@@ -437,6 +441,8 @@ def post_permute_triton_to_standard(
     # NOTE: this is dead code as a fused func for standard format is registered.
     # This is left here for testing and examples.
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
     return StandardCombineInput(
         hidden_states=runner_output.hidden_states,
     )
@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
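For reference, the failure these moves eliminate is the usual partially-initialized-module error; a minimal reproduction with hypothetical modules `a` and `b` (not sglang code):

```python
# a.py
from b import B  # starts executing b.py before a.py has finished loading


class A:
    pass


# b.py
from a import A  # a.py is still mid-import here, so this raises:
# ImportError: cannot import name 'A' from partially initialized
# module 'a' (most likely due to a circular import)


class B:
    pass
```

Moving either import under `TYPE_CHECKING` or into a function body, as the hunks above do, defers the name lookup until both modules are fully initialized.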