change/sglang · Commits · a5a03209
"src/vscode:/vscode.git/clone" did not exist on "c4a3b09a36fb22b949dc7d56f447206d5fd3b0d5"
Commit a5a03209 (Unverified)
Authored Sep 06, 2025 by Cheng Wan; committed by GitHub on Sep 06, 2025

Fix circular import (#10107)

Parent: 21af5c04
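For context, a circular import arises when two modules import each other at module load time: whichever module is imported second sees the first in a partially initialized state. A minimal two-module sketch of the failure mode, with hypothetical module names `a` and `b`:

# a.py
from b import g  # importing a starts executing b ...

def f():
    return g() + 1

# b.py
from a import f  # ... which re-enters a before f is defined, raising
                 # ImportError: cannot import name 'f' from partially
                 # initialized module 'a' (most likely due to a circular import)

def g():
    return 2

This commit breaks cycles of that shape in two standard ways, both visible in the diffs below: imports needed only for type annotations move under typing.TYPE_CHECKING (that block never executes at runtime), and imports whose values are needed at runtime move into the function bodies that use them, deferring resolution until first call.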
Showing 5 changed files with 25 additions and 33 deletions (+25, −33).
python/sglang/srt/layers/moe/moe_runner/base.py (+8, −18)
python/sglang/srt/layers/moe/moe_runner/runner.py (+1, −5)
python/sglang/srt/layers/moe/moe_runner/triton.py (+10, −4)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py (+5, −5)
python/sglang/srt/layers/quantization/w4afp8.py (+1, −1)
python/sglang/srt/layers/moe/moe_runner/base.py

@@ -6,12 +6,6 @@ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
 import torch
 
-from sglang.srt.layers.moe.token_dispatcher import (
-    CombineInput,
-    CombineInputFormat,
-    DispatchOutput,
-    DispatchOutputFormat,
-)
 from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
 
 if TYPE_CHECKING:
@@ -20,6 +14,12 @@ if TYPE_CHECKING:
         TritonRunnerInput,
         TritonRunnerOutput,
     )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
@@ -143,17 +143,12 @@ class PermuteMethodPool:
         :param runner_backend_name: The MoeRunnerBackend name.
         :param permute_func: The permute function to register.
         """
+        # TODO: check if registration is valid
         key = (dispatch_output_name, runner_backend_name)
         if key in cls._pre_permute_methods:
             raise ValueError(
                 f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
             )
-        assert DispatchOutputFormat(
-            dispatch_output_name
-        ), f"Invalid dispatch output name: {dispatch_output_name}"
-        assert MoeRunnerBackend(
-            runner_backend_name
-        ), f"Invalid runner backend name: {runner_backend_name}"
         cls._pre_permute_methods[key] = permute_func
 
     @classmethod
@@ -170,17 +165,12 @@ class PermuteMethodPool:
         :param combine_input_name: The CombineInputFormat name.
         :param permute_func: The permute function to register.
         """
+        # TODO: check if registration is valid
         key = (runner_backend_name, combine_input_name)
         if key in cls._post_permute_methods:
             raise ValueError(
                 f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
             )
-        assert MoeRunnerBackend(
-            runner_backend_name
-        ), f"Invalid runner backend name: {runner_backend_name}"
-        assert CombineInputFormat(
-            combine_input_name
-        ), f"Invalid combine input name: {combine_input_name}"
         cls._post_permute_methods[key] = permute_func
 
     @classmethod
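The base.py change is the annotation-only variant of the fix: names imported under TYPE_CHECKING remain valid in signatures (with from __future__ import annotations, annotations are stored as strings and never evaluated at runtime), but runtime uses of the now-deferred names, like the removed asserts referencing DispatchOutputFormat and CombineInputFormat, have to go, hence the new # TODO markers. A minimal sketch of the pattern, with hypothetical names:

from __future__ import annotations  # annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; this block never runs at runtime,
    # so it cannot participate in an import cycle.
    from heavy_pkg.types import DispatchResult  # hypothetical name

def permute(result: DispatchResult) -> DispatchResult:  # fine: annotation only
    # DispatchResult(...)  # would raise NameError here: the class was
    #                      # never actually imported at runtime
    return result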
python/sglang/srt/layers/moe/moe_runner/runner.py

@@ -10,15 +10,11 @@ from sglang.srt.layers.moe.moe_runner.base import (
     PermuteMethodPool,
 )
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
-from sglang.srt.layers.moe.token_dispatcher.base import (
-    CombineInput,
-    CombineInputFormat,
-    DispatchOutput,
-)
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
     from sglang.srt.layers.moe.utils import MoeRunnerBackend
 
 logger = logging.getLogger(__name__)
python/sglang/srt/layers/moe/moe_runner/triton.py

@@ -18,13 +18,16 @@ from sglang.srt.layers.moe.moe_runner.base import (
     register_post_permute,
     register_pre_permute,
 )
-from sglang.srt.layers.moe.token_dispatcher import (
-    StandardCombineInput,
-    StandardDispatchOutput,
-)
 from sglang.srt.layers.moe.utils import MoeRunnerBackend
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -325,6 +328,7 @@ def fused_experts_none_to_triton(
     runner_config: MoeRunnerConfig,
 ) -> StandardCombineInput:
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
 
     output = fused_experts(
         hidden_states=dispatch_output.hidden_states,
@@ -437,6 +441,8 @@ def post_permute_triton_to_standard(
     # NOTE: this is dead code as a fused func for standard format is registered.
     # This is left here for testing and examples.
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
     return StandardCombineInput(
         hidden_states=runner_output.hidden_states,
     )
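triton.py shows the two halves of the fix working together: StandardCombineInput is imported under TYPE_CHECKING for the annotations and imported again inside the functions that actually construct one. A minimal sketch combining both, with hypothetical names:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from moe.dispatch import StandardOutput  # hypothetical; annotations only

def run_experts(x) -> StandardOutput:
    # Deferred (function-local) import: by the time this runs, moe.dispatch
    # has finished initializing, so the cycle a top-level import would create
    # never forms. Repeat imports are cheap: a sys.modules cache lookup.
    from moe.dispatch import StandardOutput

    return StandardOutput(hidden_states=x)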
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -42,11 +42,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -439,6 +434,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
python/sglang/srt/layers/quantization/w4afp8.py

@@ -9,7 +9,6 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -297,6 +296,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
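A quick smoke test for a fix like this (a sketch, assuming the sglang package and its dependencies are importable in the current environment): importing any module that participated in the cycle pulls in the whole cycle, so a plain import of each touched module is enough to surface a regression.

import importlib

# Each import would raise ImportError ("cannot import name ... from
# partially initialized module ...") if the cycle were still present.
for mod in (
    "sglang.srt.layers.moe.moe_runner.base",
    "sglang.srt.layers.moe.moe_runner.runner",
    "sglang.srt.layers.moe.moe_runner.triton",
    "sglang.srt.layers.moe.token_dispatcher.deepep",
    "sglang.srt.layers.quantization.w4afp8",
):
    importlib.import_module(mod)
    print("ok:", mod)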