Unverified commit 7e6191c0 authored by Atream, committed by GitHub

init support for KTransformers Heterogeneous Computing (#11487)


Co-authored-by: Jianwei Dong <1913953267@qq.com>
parent 6f9b66bd
@@ -229,6 +229,14 @@ class Envs:
SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
SGLANG_RESIZE_RESAMPLE = EnvStr("")
# Ktransformers
SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
SGLANG_KT_MOE_CPUINFER = EnvInt(None)
SGLANG_KT_THREADPOOL_COUNT = EnvInt(None)
SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None)
SGLANG_KT_AMX_METHOD = EnvStr(None)
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
# fmt: on
...
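For context, the rest of this PR consumes these new variables through the same `envs` accessors used in the hunks below (`.set(...)`, `.is_set()`, `.value`). A minimal sketch of that flow, assuming sglang with this patch is importable; the values are illustrative, not defaults:

```python
# Sketch only: mirrors how the SGLANG_KT_* variables are written and read in this PR.
from sglang.srt.environ import envs

# Server args populate the variables (see override_config later in this diff).
envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(32)                    # illustrative value
envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set("/path/to/amx_weights")  # illustrative path

# Quantization code branches on whether the AMX weight path is configured.
if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
    num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value
    print(f"KTransformers AMX path enabled, {num_gpu_experts} experts stay on GPU")
```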
@@ -33,6 +33,11 @@ from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsWNA16AMXEPMoEMethod,
CompressedTensorsWNA16AMXMoEMethod,
CompressedTensorsWNA16MoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
@@ -150,7 +155,6 @@ class FusedMoE(torch.nn.Module):
with_bias=False,
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -227,6 +231,8 @@ class FusedMoE(torch.nn.Module):
if not use_weight_loader_fused
else self.weight_loader_fused
),
intermediate_size_full=intermediate_size,
top_k=top_k,
with_bias=with_bias,
)
@@ -542,6 +548,18 @@ class FusedMoE(torch.nn.Module):
if expert_id == -1:
return
if isinstance(
self.quant_method,
(
CompressedTensorsWNA16MoEMethod,
CompressedTensorsWNA16AMXMoEMethod,
CompressedTensorsWNA16AMXEPMoEMethod,
),
):
if self.quant_method.num_gpu_experts != -1:
if expert_id >= self.quant_method.num_gpu_experts:
return
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
@@ -568,7 +586,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight.t().contiguous()
if (
self.quant_method.__class__.__name__
in [
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16AMXMoEMethod",
"CompressedTensorsWNA16AMXEPMoEMethod",
]
)
else loaded_weight
)
@@ -827,7 +850,6 @@ class FusedMoE(torch.nn.Module):
dispatch_output=dispatch_output,
**kwargs,
)
final_hidden_states = self.dispatcher.combine(combine_input)
# TODO: should we add some conditions here?
...
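The weight-loader change above skips experts that will live on the CPU: when `num_gpu_experts` is configured, only expert ids below that threshold keep GPU (Marlin) weights, and the rest are served by the AMX path. A standalone sketch of that partition rule, with illustrative numbers:

```python
# Illustrative only: expert ids below num_gpu_experts stay on GPU; the remaining
# experts are handled by the CPU AMX wrapper, so the GPU weight loader returns early.
def should_load_on_gpu(expert_id: int, num_gpu_experts: int) -> bool:
    if num_gpu_experts == -1:      # feature disabled: load every expert on GPU
        return True
    return expert_id < num_gpu_experts

num_gpu_experts = 32               # e.g. --kt-num-gpu-experts 32 (illustrative)
loaded = [e for e in range(256) if should_load_on_gpu(e, num_gpu_experts)]
assert loaded == list(range(32))   # experts 32..255 are left to the CPU side
```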
class scalar_types:
uint4b8 = "uint4b8"
uint8b128 = "uint8b128"
WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
@@ -19,11 +19,13 @@ from compressed_tensors.quantization import (
)
from pydantic import BaseModel
from sglang.srt.environ import envs
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
CompressedTensorsMoEMethod,
)
@@ -38,6 +40,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
is_activation_quantization_format,
should_ignore_layer,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
try:
@@ -76,6 +79,7 @@ class DeviceCapability(NamedTuple):
class CompressedTensorsConfig(QuantizationConfig):
DeepSeekFP8Config = None
def __init__(
self,
@@ -129,6 +133,10 @@ class CompressedTensorsConfig(QuantizationConfig):
):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
if CompressedTensorsConfig.DeepSeekFP8Config is not None:
return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
return UnquantizedLinearMethod()
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedLinearMethod()
@@ -137,7 +145,8 @@ class CompressedTensorsConfig(QuantizationConfig):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE):
# Ktransformers: use CompressedTensorsWNA16AMXEPMoEMethod if AMX weights are provided
return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
return None
@classmethod
...
@@ -4,16 +4,34 @@ from __future__ import annotations
import enum
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, List
try:
from sgl_kernel import fused_marlin_moe
FUSED_MARLIN_MOE_AVAILABLE = True
except ImportError:
FUSED_MARLIN_MOE_AVAILABLE = False
try:
from kt_kernel import AMXMoEWrapper
KTRANSFORMERS_AVAILABLE = True
except ImportError:
KTRANSFORMERS_AVAILABLE = False
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.environ import envs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
@@ -21,7 +39,12 @@ from sglang.srt.layers.quantization.utils import (
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import (
get_bool_env_var,
get_compiler_backend,
is_hip,
set_weight_attrs,
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -51,6 +74,18 @@ except ImportError:
logger = logging.getLogger(__name__)
def _mask_topk_ids_cpu_experts(topk_ids: torch.Tensor, num_gpu_experts: int):
"""Mask topk_ids >= num_gpu_experts by setting them to -1."""
topk_ids[topk_ids >= num_gpu_experts] = -1
@torch.compile(dynamic=True, backend=get_compiler_backend())
def mask_cpu_expert_ids(topk_ids: torch.Tensor, num_gpu_experts: int):
"""mask CPU expert IDs."""
_mask_topk_ids_cpu_experts(topk_ids, num_gpu_experts)
return topk_ids
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
@@ -60,6 +95,7 @@ __all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsWNA16MoEMethod",
"CompressedTensorsWNA16AMXEPMoEMethod", # for Ktransformers
]
@@ -72,12 +108,24 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: CompressedTensorsConfig,
layer: torch.nn.Module,
prefix: str,
) -> "CompressedTensorsMoEMethod": ) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
match = re.search(r"(\d+)\.mlp", prefix)
if not match:
raise ValueError(
f"Unable to extract layer number from prefix '{prefix}'. "
f"Expected format: '<layer_number>.mlp'"
)
layer_number = int(match.group(1))
return CompressedTensorsWNA16AMXEPMoEMethod(quant_config, layer_number)
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
if not VLLM_AVAILABLE:
raise ImportError(
@@ -201,7 +249,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
@@ -349,7 +397,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1):
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@@ -371,6 +419,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}",
)
self.num_gpu_experts = num_gpu_experts
def create_weights(
self,
@@ -381,10 +430,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.num_gpu_experts != -1:
num_experts = self.num_gpu_experts
# assert (
#     params_dtype == torch.float16
# ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
@@ -683,3 +733,353 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
is_k_full=self.is_k_full,
)
return StandardCombineInput(hidden_states=output)
class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
"""AMX MoE method using AMXMoEWrapper for CPU inference."""
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer_idx,
num_gpu_experts,
cpuinfer,
threadpool_count,
amx_weight_path,
chunked_prefill_size,
):
if not KTRANSFORMERS_AVAILABLE:
raise ImportError(
"kt_kernel is not installed, to use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
)
if not FUSED_MARLIN_MOE_AVAILABLE:
raise ImportError("fused_marlin_moe is not available")
self.tp_rank = get_tensor_model_parallel_rank()
self.layer_idx = layer_idx
self.num_gpu_experts = num_gpu_experts
self.amx_weight_path = amx_weight_path
self.chunked_prefill_size = chunked_prefill_size
self.cpuinfer = cpuinfer
self.threadpool_count = threadpool_count
self.amx_wrapper = None
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.experts_num = num_experts
self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
self.hidden_size = hidden_size
self.moe_intermediate_size = extra_weight_attrs.pop("intermediate_size_full")
if self.tp_rank != 0:
return
self.amx_wrapper = AMXMoEWrapper(
layer_idx=self.layer_idx,
num_experts=num_experts,
num_experts_per_tok=self.num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=self.moe_intermediate_size,
num_gpu_experts=self.num_gpu_experts,
cpuinfer_threads=self.cpuinfer,
threadpool_count=self.threadpool_count,
amx_weight_path=self.amx_weight_path,
chunked_prefill_size=self.chunked_prefill_size,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.tp_rank != 0:
return
if self.amx_wrapper is None:
raise RuntimeError(
"AMXMoEWrapper not initialized. Call create_weights first."
)
torch.cuda.synchronize()
# Load weights using wrapper
from sglang.srt.eplb.expert_location_dispatch import (
get_global_expert_location_metadata,
)
physical_to_logical_map_cpu = (
get_global_expert_location_metadata()
.physical_to_logical_map_cpu[self.layer_idx]
.contiguous()
)
self.amx_wrapper.load_weights(physical_to_logical_map_cpu)
def submit(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> None:
"""Submit AMX inference task asynchronously."""
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
if self.tp_rank != 0 or self.amx_wrapper is None:
return None
# Submit forward task using wrapper
self.amx_wrapper.submit_forward(
x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
)
return None
def sync(self, x):
"""Synchronize and retrieve AMX inference results."""
if self.tp_rank != 0 or self.amx_wrapper is None:
return torch.zeros_like(x)
# Sync forward task using wrapper
return self.amx_wrapper.sync_forward(
x, torch.cuda.current_stream(x.device).cuda_stream
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
"""Execute AMX MoE forward pass synchronously."""
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
if self.tp_rank != 0 or self.amx_wrapper is None:
return StandardCombineInput(hidden_states=torch.zeros_like(x))
# Execute forward using wrapper (submit + sync)
output = self.amx_wrapper.forward(
x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
)
return StandardCombineInput(hidden_states=output)
def override_config(
cls,
num_gpu_experts,
cpuinfer,
threadpool_count,
amx_weight_path,
amx_method,
chunked_prefill_size,
):
"""Override MOE configuration via environment variables."""
# Set environment variables using envs utility class
if num_gpu_experts is not None:
envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(num_gpu_experts)
if cpuinfer is not None:
envs.SGLANG_KT_MOE_CPUINFER.set(cpuinfer)
if threadpool_count is not None:
envs.SGLANG_KT_THREADPOOL_COUNT.set(threadpool_count)
if amx_weight_path is not None:
envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set(amx_weight_path)
if amx_method is not None:
envs.SGLANG_KT_AMX_METHOD.set(amx_method)
if chunked_prefill_size is not None:
envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
layer_idx,
):
self.tp_rank = get_tensor_model_parallel_rank()
if (
not envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.is_set()
or not envs.SGLANG_KT_MOE_CPUINFER.is_set()
or not envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set()
):
raise RuntimeError(
"the following arguments are required: --kt-amx-weight-path, --kt-cpuinfer, --kt-num-gpu-experts"
)
self.num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value
cpuinfer = envs.SGLANG_KT_MOE_CPUINFER.value
threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value
self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
quant_config,
layer_idx,
self.num_gpu_experts,
cpuinfer,
threadpool_count,
amx_weight_path,
chunked_prefill_size,
)
self.marlin_method = CompressedTensorsWNA16MoEMethod(
quant_config, self.num_gpu_experts
)
self.layer_id = layer_idx
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
self.global_num_experts = num_experts
self.AMX_method.create_weights(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
self.marlin_method.create_weights(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.AMX_method.process_weights_after_loading(layer)
self.marlin_method.process_weights_after_loading(layer)
def submit(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
"""Submit hybrid GPU+CPU MoE task (AMX submission + GPU execution)."""
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, router_logits = topk_output
# Submit AMX task if on rank 0
if self.tp_rank == 0:
self.AMX_method.submit(layer, dispatch_output)
# Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)
# Execute GPU (Marlin) experts
output = fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.marlin_method.num_bits,
is_k_full=self.marlin_method.is_k_full,
global_num_experts=self.global_num_experts,
expert_map=torch.empty(1, device=x.device),
)
return StandardCombineInput(hidden_states=output)
def sync(self, x):
"""Synchronize and retrieve AMX results."""
if self.tp_rank != 0:
return torch.zeros_like(x)
return self.AMX_method.sync(x)
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
"""Execute hybrid GPU+CPU MoE forward pass with parallelism."""
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, router_logits = topk_output
# Step 1: Submit AMX task (non-blocking) if on rank 0
# This starts CPU computation in parallel
if self.tp_rank == 0:
self.AMX_method.submit(layer, dispatch_output)
# Step 2: Execute GPU (Marlin) experts in parallel with CPU
# Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)
# While GPU computes, CPU is also computing
output = fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.marlin_method.num_bits,
is_k_full=self.marlin_method.is_k_full,
global_num_experts=self.global_num_experts,
expert_map=torch.empty(1, device=x.device),
)
# Step 3: Sync AMX results and combine with GPU results
if self.tp_rank == 0:
amx_output = self.AMX_method.sync(x)
output += amx_output
return StandardCombineInput(hidden_states=output)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
self.AMX_method.create_moe_runner(layer, moe_runner_config)
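To make the hybrid split above concrete: `mask_cpu_expert_ids` rewrites any routed expert id in the CPU range to -1 so the GPU `fused_marlin_moe` kernel skips it, while the AMX wrapper computes those experts on CPU and the two partial outputs are summed. A small, self-contained illustration of the masking step (plain PyTorch, no sgl_kernel/kt_kernel required):

```python
import torch

def mask_cpu_expert_ids_ref(topk_ids: torch.Tensor, num_gpu_experts: int) -> torch.Tensor:
    # Same in-place rule as the @torch.compile'd helper in this file:
    # ids >= num_gpu_experts belong to CPU experts and must not run in the GPU kernel.
    topk_ids[topk_ids >= num_gpu_experts] = -1
    return topk_ids

topk_ids = torch.tensor([[3, 40, 7], [100, 2, 33]])
masked = mask_cpu_expert_ids_ref(topk_ids.clone(), num_gpu_experts=32)
print(masked)  # tensor([[ 3, -1,  7], [-1,  2, -1]])
```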
@@ -65,6 +65,13 @@ from sglang.srt.utils import (
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile
try:
from kt_kernel import AMXMoEWrapper
KTRANSFORMERS_AVAILABLE = True
except ImportError:
KTRANSFORMERS_AVAILABLE = False
_is_hip = is_hip()
logger = logging.getLogger(__name__)
@@ -248,6 +255,8 @@ class CudaGraphRunner:
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
if KTRANSFORMERS_AVAILABLE:
AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
self.capture_forward_mode = ForwardMode.DECODE
self.capture_hidden_mode = CaptureHiddenMode.NULL
self.num_tokens_per_bs = 1
...
@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -81,7 +82,12 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import CompressedTensorsConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsWNA16AMXEPMoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
@@ -707,6 +713,10 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
if isinstance(
self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
):
topk_output.topk_weights.mul_(self.routed_scaling_factor)
final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
@@ -2837,6 +2847,10 @@ class DeepseekV2ForCausalLM(nn.Module):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
True, "dynamic", None, [128, 128]
)
self.determine_num_fused_shared_experts()
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
@@ -2976,11 +2990,13 @@ class DeepseekV2ForCausalLM(nn.Module):
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
):
selected_quant_config = getattr(
self.quant_config, "DeepSeekFP8Config", self.quant_config
)
weight_block_size = getattr(
selected_quant_config, "weight_block_size", None
)
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if _is_fp8_fnuz:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
...
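The rewritten kv_b_proj handling above resolves `weight_block_size` from the KTransformers FP8 override when it exists, and otherwise from the model's own quant config. A minimal sketch of that getattr fallback with stand-in objects (names and values are illustrative):

```python
from types import SimpleNamespace

# Stand-ins for the two possible config sources (illustrative values only).
kt_override_cfg = SimpleNamespace(
    DeepSeekFP8Config=SimpleNamespace(weight_block_size=[128, 128])
)
plain_cfg = SimpleNamespace(weight_block_size=None)

for quant_config in (kt_override_cfg, plain_cfg):
    # Prefer the DeepSeekFP8Config override; fall back to the config itself.
    selected = getattr(quant_config, "DeepSeekFP8Config", quant_config)
    weight_block_size = getattr(selected, "weight_block_size", None)
    print(weight_block_size)  # [128, 128] on the first pass, None on the second
```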
@@ -520,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
config=config,
quant_config=quant_config,
alt_stream=alt_stream,
prefix=add_prefix("mlp", prefix),
)
else:
self.mlp = Qwen2MoeMLP(
@@ -673,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
config=config,
quant_config=quant_config,
alt_stream=alt_stream,
prefix=add_prefix("mlp", prefix),
)
else:
self.mlp = Qwen2MoeMLP(
...
@@ -91,6 +91,7 @@ QUANTIZATION_CHOICES = [
"qoq",
"w4afp8",
"mxfp4",
"compressed-tensors", # for Ktransformers
]
ATTENTION_BACKEND_CHOICES = [
@@ -389,6 +390,13 @@ class ServerArgs:
# LMCache
enable_lmcache: bool = False
# Ktransformers
kt_amx_weight_path: Optional[str] = None
kt_amx_method: Optional[str] = None
kt_cpuinfer: Optional[int] = None
kt_threadpool_count: Optional[int] = None
kt_num_gpu_experts: Optional[int] = None
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: Optional[str] = None
@@ -544,6 +552,9 @@ class ServerArgs:
self._handle_amd_specifics()
self._handle_grammar_backend()
# Handle Ktransformers specific configs
self._handle_ktransformers_configs()
# Handle data parallelism.
self._handle_data_parallelism()
@@ -595,6 +606,22 @@ class ServerArgs:
)
self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser]
def _handle_ktransformers_configs(self):
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
CompressedTensorsWNA16AMXEPMoEMethod,
override_config,
)
override_config(
CompressedTensorsWNA16AMXEPMoEMethod,
self.kt_num_gpu_experts,
self.kt_cpuinfer,
self.kt_threadpool_count,
self.kt_amx_weight_path,
self.kt_amx_method,
self.chunked_prefill_size,
)
def _handle_missing_default_values(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -1518,6 +1545,7 @@ class ServerArgs:
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
# Model and tokenizer
parser.add_argument(
"--model-path",
@@ -2675,6 +2703,35 @@ class ServerArgs:
help="Using LMCache as an alternative hierarchical cache solution",
)
# Ktransformers server args
parser.add_argument(
"--kt-amx-weight-path",
type=str,
help="[ktransformers parameter] The path of the quantized expert weights for amx kernel. A local folder.",
)
parser.add_argument(
"--kt-amx-method",
type=str,
default="AMXINT4",
help="[ktransformers parameter] Quantization formats for CPU execution.",
)
parser.add_argument(
"--kt-cpuinfer",
type=int,
help="[ktransformers parameter] The number of CPUInfer threads.",
)
parser.add_argument(
"--kt-threadpool-count",
type=int,
default=2,
help="[ktransformers parameter] One-to-one with the number of NUMA nodes (one thread pool per NUMA).",
)
parser.add_argument(
"--kt-num-gpu-experts",
type=int,
help="[ktransformers parameter] The number of GPU experts.",
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
...
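For reference, `_handle_ktransformers_configs` simply forwards these CLI values into the SGLANG_KT_* environment variables through `override_config`. A hedged sketch of an equivalent direct call, assuming sglang with this patch is installed; the argument values are illustrative, not defaults:

```python
# Illustrative: how the new server args reach the MoE quantization code.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16AMXEPMoEMethod,
    override_config,
)

override_config(
    CompressedTensorsWNA16AMXEPMoEMethod,
    32,                      # --kt-num-gpu-experts
    64,                      # --kt-cpuinfer
    2,                       # --kt-threadpool-count
    "/path/to/amx_weights",  # --kt-amx-weight-path (illustrative path)
    "AMXINT4",               # --kt-amx-method
    8192,                    # chunked_prefill_size from server args (illustrative)
)
```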