zhaoyu6 / sglang · Commits

Commit 7e6191c0 (unverified), authored Oct 21, 2025 by Atream, committed by GitHub on Oct 21, 2025

init support for KTransformers Heterogeneous Computing (#11487)

Co-authored-by: Jianwei Dong <1913953267@qq.com>

parent 6f9b66bd
Changes: showing 9 changed files with 547 additions and 17 deletions (+547 −17).
python/sglang/srt/environ.py                                                           +8   -0
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                                 +25  -3
python/sglang/srt/layers/quantization/compressed_tensors/__init__.py                   +7   -0
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py         +10  -1
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py     +408 -8
python/sglang/srt/model_executor/cuda_graph_runner.py                                  +9   -0
python/sglang/srt/models/deepseek_v2.py                                                +21  -5
python/sglang/srt/models/qwen3_next.py                                                 +2   -0
python/sglang/srt/server_args.py                                                       +57  -0
python/sglang/srt/environ.py  (view file @ 7e6191c0)

...
@@ -229,6 +229,14 @@ class Envs:
    SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
    SGLANG_RESIZE_RESAMPLE = EnvStr("")

    # Ktransformers
    SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
    SGLANG_KT_MOE_CPUINFER = EnvInt(None)
    SGLANG_KT_THREADPOOL_COUNT = EnvInt(None)
    SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None)
    SGLANG_KT_AMX_METHOD = EnvStr(None)
    SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
    # fmt: on
...
...
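For orientation, a short sketch of how these typed env wrappers are driven; the .set(), .is_set(), and .value accessors are the ones used later in this commit (in server_args.py and compressed_tensors_moe.py), while the concrete values here are placeholders:

from sglang.srt.environ import envs

# Placeholder values; in practice these are filled in from the --kt-* CLI flags.
envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(8)
envs.SGLANG_KT_MOE_CPUINFER.set(32)
envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set("/path/to/amx_weights")

if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
    print(envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value)  # -> 8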
python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (view file @ 7e6191c0)

...
@@ -33,6 +33,11 @@ from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16AMXEPMoEMethod,
    CompressedTensorsWNA16AMXMoEMethod,
    CompressedTensorsWNA16MoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
...
@@ -150,7 +155,6 @@ class FusedMoE(torch.nn.Module):
        with_bias=False,
    ):
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
...
@@ -227,6 +231,8 @@ class FusedMoE(torch.nn.Module):
                if not use_weight_loader_fused
                else self.weight_loader_fused
            ),
            intermediate_size_full=intermediate_size,
            top_k=top_k,
            with_bias=with_bias,
        )
...
@@ -542,6 +548,18 @@ class FusedMoE(torch.nn.Module):
        if expert_id == -1:
            return
        if isinstance(
            self.quant_method,
            (
                CompressedTensorsWNA16MoEMethod,
                CompressedTensorsWNA16AMXMoEMethod,
                CompressedTensorsWNA16AMXEPMoEMethod,
            ),
        ):
            if self.quant_method.num_gpu_experts != -1:
                if expert_id >= self.quant_method.num_gpu_experts:
                    return
        self._weight_loader_impl(
            param=param,
            loaded_weight=loaded_weight,
...
@@ -568,7 +586,12 @@ class FusedMoE(torch.nn.Module):
                loaded_weight.t().contiguous()
                if (
                    self.quant_method.__class__.__name__
-                   == "CompressedTensorsWNA16MoEMethod"
+                   in [
+                       "CompressedTensorsWNA16MarlinMoEMethod",
+                       "CompressedTensorsWNA16MoEMethod",
+                       "CompressedTensorsWNA16AMXMoEMethod",
+                       "CompressedTensorsWNA16AMXEPMoEMethod",
+                   ]
                )
                else loaded_weight
            )
...
@@ -827,7 +850,6 @@ class FusedMoE(torch.nn.Module):
            dispatch_output=dispatch_output,
            **kwargs,
        )
        final_hidden_states = self.dispatcher.combine(combine_input)
        # TODO: should we add some conditions here?
...
...
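The new weight-loader guard above is where the GPU/CPU expert split happens at load time: experts with expert_id >= num_gpu_experts live on the CPU (AMX) side, so their GPU tensors are never materialized. A standalone sketch of that predicate (hypothetical helper, not part of the diff):

def loads_on_gpu(expert_id: int, num_gpu_experts: int) -> bool:
    """Mirror of the skip logic in weight_loader above (illustrative)."""
    if expert_id == -1:            # sentinel: weight is skipped entirely
        return False
    if num_gpu_experts == -1:      # -1 means no CPU offload is configured
        return True
    return expert_id < num_gpu_experts  # experts [0, num_gpu_experts) stay on GPU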
python/sglang/srt/layers/quantization/compressed_tensors/__init__.py  (view file @ 7e6191c0)

class scalar_types:
    uint4b8 = "uint4b8"
    uint8b128 = "uint8b128"


WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
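For reference, how these module-level constants are meant to be consumed (an illustrative check, not code from the commit):

from sglang.srt.layers.quantization.compressed_tensors import (
    WNA16_SUPPORTED_BITS,
    WNA16_SUPPORTED_TYPES_MAP,
)

num_bits = 4
assert num_bits in WNA16_SUPPORTED_BITS      # WNA16_SUPPORTED_BITS == [4, 8]
print(WNA16_SUPPORTED_TYPES_MAP[num_bits])   # "uint4b8"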
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  (view file @ 7e6191c0)

...
@@ -19,11 +19,13 @@ from compressed_tensors.quantization import (
)
from pydantic import BaseModel

from sglang.srt.environ import envs
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
    CompressedTensorsMoEMethod,
)
...
@@ -38,6 +40,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format,
    should_ignore_layer,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

try:
...
@@ -76,6 +79,7 @@ class DeviceCapability(NamedTuple):
class CompressedTensorsConfig(QuantizationConfig):
    DeepSeekFP8Config = None

    def __init__(
        self,
...
@@ -129,6 +133,10 @@ class CompressedTensorsConfig(QuantizationConfig):
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            if CompressedTensorsConfig.DeepSeekFP8Config is not None:
                return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
            if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
                return UnquantizedLinearMethod()
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedLinearMethod()
...
@@ -137,7 +145,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
-            return CompressedTensorsMoEMethod.get_moe_method(self)
+            # KTransformers uses CompressedTensorsWNA16AMXMoEMethod if AMX weights are provided
+            return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
        return None

    @classmethod
...
...
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  (view file @ 7e6191c0)

...
@@ -4,16 +4,34 @@ from __future__ import annotations

import enum
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, List

try:
    from sgl_kernel import fused_marlin_moe

    FUSED_MARLIN_MOE_AVAILABLE = True
except ImportError:
    FUSED_MARLIN_MOE_AVAILABLE = False

try:
    from kt_kernel import AMXMoEWrapper

    KTRANSFORMERS_AVAILABLE = True
except ImportError:
    KTRANSFORMERS_AVAILABLE = False

import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.environ import envs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
...
@@ -21,7 +39,12 @@ from sglang.srt.layers.quantization.utils import (
    per_tensor_dequantize,
    replace_parameter,
)
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_compiler_backend,
+    is_hip,
+    set_weight_attrs,
+)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
...
@@ -51,6 +74,18 @@ except ImportError:
logger = logging.getLogger(__name__)


def _mask_topk_ids_cpu_experts(topk_ids: torch.Tensor, num_gpu_experts: int):
    """Mask topk_ids >= num_gpu_experts by setting them to -1."""
    topk_ids[topk_ids >= num_gpu_experts] = -1


@torch.compile(dynamic=True, backend=get_compiler_backend())
def mask_cpu_expert_ids(topk_ids: torch.Tensor, num_gpu_experts: int):
    """Mask CPU expert IDs."""
    _mask_topk_ids_cpu_experts(topk_ids, num_gpu_experts)
    return topk_ids


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()
...
...
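A toy run of the helper defined above, showing the in-place masking (values are illustrative; with num_gpu_experts=4, experts 4 and up are CPU-side):

import torch

topk_ids = torch.tensor([[0, 3, 7], [2, 5, 6]])
masked = mask_cpu_expert_ids(topk_ids.clone(), num_gpu_experts=4)
print(masked)  # tensor([[ 0,  3, -1],
               #         [ 2, -1, -1]])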
@@ -60,6 +95,7 @@ __all__ = [
    "CompressedTensorsMoEMethod",
    "CompressedTensorsW8A8Fp8MoEMethod",
    "CompressedTensorsWNA16MoEMethod",
    "CompressedTensorsWNA16AMXEPMoEMethod",  # for Ktransformers
]
...
@@ -72,12 +108,24 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    @staticmethod
    def get_moe_method(
        quant_config: CompressedTensorsConfig,
        layer: torch.nn.Module,
        prefix: str,
    ) -> "CompressedTensorsMoEMethod":
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
            match = re.search(r"(\d+)\.mlp", prefix)
            if not match:
                raise ValueError(
                    f"Unable to extract layer number from prefix '{prefix}'. "
                    f"Expected format: '<layer_number>.mlp'"
                )
            layer_number = int(match.group(1))
            return CompressedTensorsWNA16AMXEPMoEMethod(quant_config, layer_number)
        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
...
...
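The layer index is recovered from the module prefix rather than passed in explicitly; a quick check of the pattern used above (the prefix string is an illustrative example of the expected '<layer_number>.mlp' shape):

import re

prefix = "model.layers.12.mlp.experts"  # illustrative prefix
match = re.search(r"(\d+)\.mlp", prefix)
assert match is not None
print(int(match.group(1)))  # -> 12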
@@ -201,7 +249,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        layer.w13_input_scale = None
        layer.w2_input_scale = None

-    def process_weights_after_loading(self, layer: FusedMoE) -> None:
+    def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
...
@@ -349,7 +397,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

-    def __init__(self, quant_config: CompressedTensorsConfig):
+    def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1):
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
...
@@ -371,6 +419,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                "is supported for the following bits: ",
                f"{WNA16_SUPPORTED_BITS}",
            )
        self.num_gpu_experts = num_gpu_experts

    def create_weights(
        self,
...
@@ -381,10 +430,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
-        assert (
-            params_dtype == torch.float16
-        ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        if self.num_gpu_experts != -1:
+            num_experts = self.num_gpu_experts
+        # assert (
+        #     params_dtype == torch.float16
+        # ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
...
...
@@ -683,3 +733,353 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            is_k_full=self.is_k_full,
        )
        return StandardCombineInput(hidden_states=output)


class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
    """AMX MoE method using AMXMoEWrapper for CPU inference."""

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa: E501
        layer_idx,
        num_gpu_experts,
        cpuinfer,
        threadpool_count,
        amx_weight_path,
        chunked_prefill_size,
    ):
        if not KTRANSFORMERS_AVAILABLE:
            raise ImportError(
                "kt_kernel is not installed; to use CompressedTensorsWNA16AMXEPMoEMethod, "
                "please install kt_kernel."
            )
        if not FUSED_MARLIN_MOE_AVAILABLE:
            raise ImportError("fused_marlin_moe is not available")
        self.tp_rank = get_tensor_model_parallel_rank()
        self.layer_idx = layer_idx
        self.num_gpu_experts = num_gpu_experts
        self.amx_weight_path = amx_weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.cpuinfer = cpuinfer
        self.threadpool_count = threadpool_count
        self.amx_wrapper = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.experts_num = num_experts
        self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
        self.hidden_size = hidden_size
        self.moe_intermediate_size = extra_weight_attrs.pop("intermediate_size_full")
        if self.tp_rank != 0:
            return
        self.amx_wrapper = AMXMoEWrapper(
            layer_idx=self.layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=self.moe_intermediate_size,
            num_gpu_experts=self.num_gpu_experts,
            cpuinfer_threads=self.cpuinfer,
            threadpool_count=self.threadpool_count,
            amx_weight_path=self.amx_weight_path,
            chunked_prefill_size=self.chunked_prefill_size,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.tp_rank != 0:
            return
        if self.amx_wrapper is None:
            raise RuntimeError(
                "AMXMoEWrapper not initialized. Call create_weights first."
            )
        torch.cuda.synchronize()
        # Load weights using the wrapper
        from sglang.srt.eplb.expert_location_dispatch import (
            get_global_expert_location_metadata,
        )

        physical_to_logical_map_cpu = (
            get_global_expert_location_metadata()
            .physical_to_logical_map_cpu[self.layer_idx]
            .contiguous()
        )
        self.amx_wrapper.load_weights(physical_to_logical_map_cpu)

    def submit(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> None:
        """Submit AMX inference task asynchronously."""
        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output
        if self.tp_rank != 0 or self.amx_wrapper is None:
            return None
        # Submit the forward task using the wrapper
        self.amx_wrapper.submit_forward(
            x,
            topk_ids,
            topk_weights,
            torch.cuda.current_stream(x.device).cuda_stream,
        )
        return None

    def sync(self, x):
        """Synchronize and retrieve AMX inference results."""
        if self.tp_rank != 0 or self.amx_wrapper is None:
            return torch.zeros_like(x)
        # Sync the forward task using the wrapper
        return self.amx_wrapper.sync_forward(
            x, torch.cuda.current_stream(x.device).cuda_stream
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Execute AMX MoE forward pass synchronously."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output
        if self.tp_rank != 0 or self.amx_wrapper is None:
            return StandardCombineInput(hidden_states=torch.zeros_like(x))
        # Execute forward using the wrapper (submit + sync)
        output = self.amx_wrapper.forward(
            x,
            topk_ids,
            topk_weights,
            torch.cuda.current_stream(x.device).cuda_stream,
        )
        return StandardCombineInput(hidden_states=output)
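The split into submit() and sync() lets a caller start the CPU experts, do GPU work, then block only when the CPU result is needed; only TP rank 0 drives CPUInfer, and the other ranks get zeros back. A schematic caller (hypothetical names; the real orchestration is CompressedTensorsWNA16AMXEPMoEMethod below):

# Hypothetical orchestration sketch; amx_method, run_gpu_experts, and x are placeholders.
amx_method.submit(layer, dispatch_output)   # non-blocking: CPU experts start computing
gpu_out = run_gpu_experts(dispatch_output)  # GPU-side kernel runs in the meantime
cpu_out = amx_method.sync(x)                # block until the CPU partial result is ready
out = gpu_out + cpu_out                     # partial sums combine additively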
def override_config(
    cls,
    num_gpu_experts,
    cpuinfer,
    threadpool_count,
    amx_weight_path,
    amx_method,
    chunked_prefill_size,
):
    """Override MoE configuration via environment variables."""
    # Set environment variables using the envs utility class
    if num_gpu_experts is not None:
        envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(num_gpu_experts)
    if cpuinfer is not None:
        envs.SGLANG_KT_MOE_CPUINFER.set(cpuinfer)
    if threadpool_count is not None:
        envs.SGLANG_KT_THREADPOOL_COUNT.set(threadpool_count)
    if amx_weight_path is not None:
        envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set(amx_weight_path)
    if amx_method is not None:
        envs.SGLANG_KT_AMX_METHOD.set(amx_method)
    if chunked_prefill_size is not None:
        envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
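override_config() is the bridge from the --kt-* CLI flags to the SGLANG_KT_* env wrappers; server_args.py calls it once during argument post-processing. A sketch of that call with placeholder values:

override_config(
    CompressedTensorsWNA16AMXEPMoEMethod,
    8,                        # num_gpu_experts
    32,                       # cpuinfer
    2,                        # threadpool_count
    "/path/to/amx_weights",   # amx_weight_path (placeholder)
    "AMXINT4",                # amx_method
    8192,                     # chunked_prefill_size
)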
class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa: E501
        layer_idx,
    ):
        self.tp_rank = get_tensor_model_parallel_rank()
        if (
            not envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.is_set()
            or not envs.SGLANG_KT_MOE_CPUINFER.is_set()
            or not envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set()
        ):
            raise RuntimeError(
                "the following arguments are required: "
                "--kt-amx-weight-path, --kt-cpuinfer, --kt-num-gpu-experts"
            )
        self.num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value
        cpuinfer = envs.SGLANG_KT_MOE_CPUINFER.value
        threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
        amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
        chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value
        self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
            quant_config,
            layer_idx,
            self.num_gpu_experts,
            cpuinfer,
            threadpool_count,
            amx_weight_path,
            chunked_prefill_size,
        )
        self.marlin_method = CompressedTensorsWNA16MoEMethod(
            quant_config, self.num_gpu_experts
        )
        self.layer_id = layer_idx

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.global_num_experts = num_experts
        self.AMX_method.create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.marlin_method.create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.AMX_method.process_weights_after_loading(layer)
        self.marlin_method.process_weights_after_loading(layer)

    def submit(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Submit hybrid GPU+CPU MoE task (AMX submission + GPU execution)."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, router_logits = topk_output

        # Submit the AMX task if on rank 0
        if self.tp_rank == 0:
            self.AMX_method.submit(layer, dispatch_output)

        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)

        # Execute GPU (Marlin) experts
        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.marlin_method.num_bits,
            is_k_full=self.marlin_method.is_k_full,
            global_num_experts=self.global_num_experts,
            expert_map=torch.empty(1, device=x.device),
        )
        return StandardCombineInput(hidden_states=output)

    def sync(self, x):
        """Synchronize and retrieve AMX results."""
        if self.tp_rank != 0:
            return torch.zeros_like(x)
        return self.AMX_method.sync(x)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Execute hybrid GPU+CPU MoE forward pass with parallelism."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, router_logits = topk_output

        # Step 1: Submit the AMX task (non-blocking) if on rank 0.
        # This starts the CPU computation in parallel.
        if self.tp_rank == 0:
            self.AMX_method.submit(layer, dispatch_output)

        # Step 2: Execute the GPU (Marlin) experts in parallel with the CPU.
        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU.
        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)

        # While the GPU computes, the CPU is computing too.
        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.marlin_method.num_bits,
            is_k_full=self.marlin_method.is_k_full,
            global_num_experts=self.global_num_experts,
            expert_map=torch.empty(1, device=x.device),
        )

        # Step 3: Sync the AMX results and combine them with the GPU results.
        if self.tp_rank == 0:
            amx_output = self.AMX_method.sync(x)
            output += amx_output
        return StandardCombineInput(hidden_states=output)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.AMX_method.create_moe_runner(layer, moe_runner_config)
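Read together with submit()/sync() above, apply() realizes the following per-layer schedule; this comment sketch summarizes the dataflow rather than adding new behavior:

# rank 0:  AMX_method.submit(...)           # CPU experts start (ids >= num_gpu_experts)
# all:     mask_cpu_expert_ids(topk_ids)    # CPU-bound ids become -1 for the GPU kernel
# all:     fused_marlin_moe(...)            # GPU experts run while the CPU computes
# rank 0:  amx_output = AMX_method.sync(x)  # block until the CPU partial result is ready
# rank 0:  output += amx_output             # partial sums combine additively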
python/sglang/srt/model_executor/cuda_graph_runner.py  (view file @ 7e6191c0)

...
@@ -65,6 +65,13 @@ from sglang.srt.utils import (
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile

try:
    from kt_kernel import AMXMoEWrapper

    KTRANSFORMERS_AVAILABLE = True
except ImportError:
    KTRANSFORMERS_AVAILABLE = False

_is_hip = is_hip()

logger = logging.getLogger(__name__)
...
@@ -248,6 +255,8 @@ class CudaGraphRunner:
        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        if KTRANSFORMERS_AVAILABLE:
            AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
...
...
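Presumably the CPU wrapper needs the exact capture sizes so it can size its work per captured graph; a minimal sketch of the handshake (the batch sizes here are placeholders for get_batch_sizes_to_capture() output):

from kt_kernel import AMXMoEWrapper  # only available when kt_kernel is installed

capture_bs = [1, 2, 4, 8]  # placeholder capture sizes
AMXMoEWrapper.set_capture_batch_sizes(capture_bs)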
python/sglang/srt/models/deepseek_v2.py  (view file @ 7e6191c0)

...
@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
...
@@ -81,7 +82,12 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import CompressedTensorsConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16AMXEPMoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
    is_fp8_fnuz,
    per_tensor_quant_mla_fp8,
...
@@ -707,6 +713,10 @@ class DeepseekV2MoE(nn.Module):
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
        topk_output = self.topk(hidden_states, router_logits)
        if isinstance(self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod):
            topk_output.topk_weights.mul_(self.routed_scaling_factor)
        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
...
...
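For the AMX-EP path the routed scaling factor is folded into the top-k weights before dispatch instead of being applied to the combined output afterwards; the two are equivalent because expert outputs are combined as a weighted sum. A quick numerical check (tensors are illustrative):

import torch

s = 2.5                                  # stand-in for routed_scaling_factor
w = torch.tensor([0.7, 0.3])             # top-k routing weights (illustrative)
e = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])           # per-expert outputs (illustrative)

scale_after = s * (w[:, None] * e).sum(dim=0)      # scale the combined output
scale_weights = ((s * w)[:, None] * e).sum(dim=0)  # scale the weights up front
assert torch.allclose(scale_after, scale_weights)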
@@ -2837,6 +2847,10 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
            CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
                True, "dynamic", None, [128, 128]
            )
        self.determine_num_fused_shared_experts()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
...
@@ -2976,11 +2990,13 @@ class DeepseekV2ForCausalLM(nn.Module):
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ):
-            if (
-                hasattr(self.quant_config, "weight_block_size")
-                and self.quant_config.weight_block_size is not None
-            ):
-                weight_block_size = self.quant_config.weight_block_size
+            selected_quant_config = getattr(
+                self.quant_config, "DeepSeekFP8Config", self.quant_config
+            )
+            weight_block_size = getattr(
+                selected_quant_config, "weight_block_size", None
+            )
+            if weight_block_size is not None:
                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                if _is_fp8_fnuz:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
...
...
python/sglang/srt/models/qwen3_next.py  (view file @ 7e6191c0)

...
@@ -520,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen2MoeMLP(
...
@@ -673,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen2MoeMLP(
...
...
python/sglang/srt/server_args.py  (view file @ 7e6191c0)

...
@@ -91,6 +91,7 @@ QUANTIZATION_CHOICES = [
    "qoq",
    "w4afp8",
    "mxfp4",
    "compressed-tensors",  # for Ktransformers
]

ATTENTION_BACKEND_CHOICES = [
...
@@ -389,6 +390,13 @@ class ServerArgs:
    # LMCache
    enable_lmcache: bool = False

    # Ktransformers
    kt_amx_weight_path: Optional[str] = None
    kt_amx_method: Optional[str] = None
    kt_cpuinfer: Optional[int] = None
    kt_threadpool_count: Optional[int] = None
    kt_num_gpu_experts: Optional[int] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
...
@@ -544,6 +552,9 @@ class ServerArgs:
        self._handle_amd_specifics()
        self._handle_grammar_backend()

        # Handle Ktransformers specific configs
        self._handle_ktransformers_configs()

        # Handle data parallelism.
        self._handle_data_parallelism()
...
@@ -595,6 +606,22 @@ class ServerArgs:
            )
            self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser]

    def _handle_ktransformers_configs(self):
        from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
            CompressedTensorsWNA16AMXEPMoEMethod,
            override_config,
        )

        override_config(
            CompressedTensorsWNA16AMXEPMoEMethod,
            self.kt_num_gpu_experts,
            self.kt_cpuinfer,
            self.kt_threadpool_count,
            self.kt_amx_weight_path,
            self.kt_amx_method,
            self.chunked_prefill_size,
        )

    def _handle_missing_default_values(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
...
@@ -1518,6 +1545,7 @@ class ServerArgs:
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and tokenizer
        parser.add_argument(
            "--model-path",
...
@@ -2675,6 +2703,35 @@ class ServerArgs:
            help="Using LMCache as an alternative hierarchical cache solution",
        )

        # Ktransformers server args
        parser.add_argument(
            "--kt-amx-weight-path",
            type=str,
            help="[ktransformers parameter] The path of the quantized expert weights for the AMX kernel (a local folder).",
        )
        parser.add_argument(
            "--kt-amx-method",
            type=str,
            default="AMXINT4",
            help="[ktransformers parameter] The quantization format for CPU execution.",
        )
        parser.add_argument(
            "--kt-cpuinfer",
            type=int,
            help="[ktransformers parameter] The number of CPUInfer threads.",
        )
        parser.add_argument(
            "--kt-threadpool-count",
            type=int,
            default=2,
            help="[ktransformers parameter] The number of thread pools, one per NUMA node.",
        )
        parser.add_argument(
            "--kt-num-gpu-experts",
            type=int,
            help="[ktransformers parameter] The number of GPU experts.",
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
...
...
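Taken together, a hypothetical launch that exercises the new flags (the model and paths are placeholders; the flags are forwarded to the SGLANG_KT_* envs by _handle_ktransformers_configs() above):

import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V2",       # placeholder model
    "--quantization", "compressed-tensors",
    "--kt-amx-weight-path", "/path/to/amx_weights",  # placeholder path
    "--kt-amx-method", "AMXINT4",
    "--kt-cpuinfer", "32",
    "--kt-threadpool-count", "2",
    "--kt-num-gpu-experts", "8",
])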