Unverified commit a5a03209, authored by Cheng Wan, committed by GitHub
Browse files

Fix circular import (#10107)

parent 21af5c04
...@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard ...@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
import torch import torch
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -20,6 +14,12 @@ if TYPE_CHECKING: ...@@ -20,6 +14,12 @@ if TYPE_CHECKING:
TritonRunnerInput, TritonRunnerInput,
TritonRunnerOutput, TritonRunnerOutput,
) )
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
CombineInputFormat,
DispatchOutput,
DispatchOutputFormat,
)
@dataclass @dataclass
...@@ -143,17 +143,12 @@ class PermuteMethodPool: ...@@ -143,17 +143,12 @@ class PermuteMethodPool:
:param runner_backend_name: The MoeRunnerBackend name. :param runner_backend_name: The MoeRunnerBackend name.
:param permute_func: The permute function to register. :param permute_func: The permute function to register.
""" """
# TODO: check if registration is valid
key = (dispatch_output_name, runner_backend_name) key = (dispatch_output_name, runner_backend_name)
if key in cls._pre_permute_methods: if key in cls._pre_permute_methods:
raise ValueError( raise ValueError(
f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered." f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
) )
assert DispatchOutputFormat(
dispatch_output_name
), f"Invalid dispatch output name: {dispatch_output_name}"
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
cls._pre_permute_methods[key] = permute_func cls._pre_permute_methods[key] = permute_func
@classmethod @classmethod
...@@ -170,17 +165,12 @@ class PermuteMethodPool: ...@@ -170,17 +165,12 @@ class PermuteMethodPool:
:param combine_input_name: The CombineInputFormat name. :param combine_input_name: The CombineInputFormat name.
:param permute_func: The permute function to register. :param permute_func: The permute function to register.
""" """
# TODO: check if registration is valid
key = (runner_backend_name, combine_input_name) key = (runner_backend_name, combine_input_name)
if key in cls._post_permute_methods: if key in cls._post_permute_methods:
raise ValueError( raise ValueError(
f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered." f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
) )
assert MoeRunnerBackend(
runner_backend_name
), f"Invalid runner backend name: {runner_backend_name}"
assert CombineInputFormat(
combine_input_name
), f"Invalid combine input name: {combine_input_name}"
cls._post_permute_methods[key] = permute_func cls._post_permute_methods[key] = permute_func
@classmethod @classmethod
......
...@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import ( ...@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import (
PermuteMethodPool, PermuteMethodPool,
) )
from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
from sglang.srt.layers.moe.token_dispatcher.base import (
CombineInput,
CombineInputFormat,
DispatchOutput,
)
from sglang.srt.layers.moe.utils import get_moe_a2a_backend from sglang.srt.layers.moe.utils import get_moe_a2a_backend
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.layers.moe.utils import MoeRunnerBackend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import ( ...@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import (
register_post_permute, register_post_permute,
register_pre_permute, register_pre_permute,
) )
from sglang.srt.layers.moe.token_dispatcher import (
StandardCombineInput,
StandardDispatchOutput,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend from sglang.srt.layers.moe.utils import MoeRunnerBackend
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
StandardDispatchOutput,
)
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support() _is_cpu_amx_available = cpu_has_amx_support()
...@@ -325,6 +328,7 @@ def fused_experts_none_to_triton( ...@@ -325,6 +328,7 @@ def fused_experts_none_to_triton(
runner_config: MoeRunnerConfig, runner_config: MoeRunnerConfig,
) -> StandardCombineInput: ) -> StandardCombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
output = fused_experts( output = fused_experts(
hidden_states=dispatch_output.hidden_states, hidden_states=dispatch_output.hidden_states,
...@@ -437,6 +441,8 @@ def post_permute_triton_to_standard( ...@@ -437,6 +441,8 @@ def post_permute_triton_to_standard(
# NOTE: this is dead code as a fused func for standard format is registered. # NOTE: this is dead code as a fused func for standard format is registered.
# This is left here for testing and examples. # This is left here for testing and examples.
from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
return StandardCombineInput( return StandardCombineInput(
hidden_states=runner_output.hidden_states, hidden_states=runner_output.hidden_states,
) )
...@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto ...@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
...@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
topk_idx: torch.Tensor, topk_idx: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
): ):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_post_reorder_triton_kernel,
)
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter: if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
output = hidden_states output = hidden_states
else: else:
......
...@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter ...@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase, FusedMoEMethodBase,
QuantizationConfig, QuantizationConfig,
...@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): ...@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
dispatch_output: StandardDispatchOutput, dispatch_output: StandardDispatchOutput,
) -> CombineInput: ) -> CombineInput:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states x = dispatch_output.hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment