Unverified commit 627bac64, authored by Atream, committed by GitHub

Support Expert Deferral Mechanism in KTransformers (#12586)

Adds a --kt-max-deferred-experts-per-token server argument, propagated through the new SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN and SGLANG_KT_MOE_TOTAL_LAYERS environment variables, so the CompressedTensorsWNA16AMX MoE path can cap how many experts per token are deferred to the CPU; the final MoE layer always uses a cap of 0.

Co-authored-by: Chen Hongtao <56470055+chenht2022@users.noreply.github.com>
Co-authored-by: chenht2022 <cht22@mails.tsinghua.edu.cn>
parent 3651cfbf
@@ -270,6 +270,8 @@ class Envs:
SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None)
SGLANG_KT_AMX_METHOD = EnvStr(None)
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN = EnvInt(None)
SGLANG_KT_MOE_TOTAL_LAYERS = EnvInt(None)
# Sparse Embeddings
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
......
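For context, a minimal sketch of how these two new environment handles are exercised elsewhere in this commit. The import path and the example values are the editor's assumptions, not part of the patch; the .set()/.value accessors mirror their use in the quantization and server-args hunks below.

    # Editor's sketch, not part of the patch; import path and values are assumed.
    from sglang.srt.environ import envs

    # override_config() writes the configuration through the env handles ...
    envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.set(2)   # illustrative cap
    envs.SGLANG_KT_MOE_TOTAL_LAYERS.set(61)                    # illustrative layer count

    # ... and the AMX EP MoE method reads them back, as in the hunks below.
    max_deferred = envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.value
    total_layers = envs.SGLANG_KT_MOE_TOTAL_LAYERS.value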
@@ -705,8 +705,9 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
threadpool_count,
amx_weight_path,
chunked_prefill_size,
max_deferred_experts_per_token,
total_num_hidden_layers,
):
if not KTRANSFORMERS_AVAILABLE:
raise ImportError(
"kt_kernel is not installed, to use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
@@ -723,6 +724,8 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
self.cpuinfer = cpuinfer
self.threadpool_count = threadpool_count
self.amx_wrapper = None
self.max_deferred_experts_per_token = max_deferred_experts_per_token
self.total_num_hidden_layers = total_num_hidden_layers
def create_weights(
self,
@@ -733,6 +736,13 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer_max_deferred = self.max_deferred_experts_per_token or 0
if (
self.max_deferred_experts_per_token is not None
and self.total_num_hidden_layers is not None
and self.layer_idx == self.total_num_hidden_layers - 1
):
layer_max_deferred = 0
self.experts_num = num_experts
self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
self.hidden_size = hidden_size
@@ -751,6 +761,7 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
threadpool_count=self.threadpool_count,
amx_weight_path=self.amx_weight_path,
chunked_prefill_size=self.chunked_prefill_size,
max_deferred_experts_per_token=layer_max_deferred,
amx_method=envs.SGLANG_KT_AMX_METHOD.value,
)
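The hunk above is where the new cap takes effect per layer: every MoE layer passes the configured maximum on to the CPU-side wrapper, except the final layer, whose cap is pinned to 0. A standalone restatement of that rule, with function and argument names chosen by the editor for illustration:

    # Editor's sketch restating the layer_max_deferred logic above; names are illustrative.
    def resolve_layer_deferral_cap(max_deferred_experts_per_token, total_num_hidden_layers, layer_idx):
        # An unset cap (None) means no deferral.
        cap = max_deferred_experts_per_token or 0
        # The last hidden layer always uses a cap of 0, as in the patch.
        if (
            max_deferred_experts_per_token is not None
            and total_num_hidden_layers is not None
            and layer_idx == total_num_hidden_layers - 1
        ):
            cap = 0
        return cap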
@@ -848,6 +859,8 @@ def override_config(
amx_weight_path,
amx_method,
chunked_prefill_size,
max_deferred_experts_per_token,
num_hidden_layers,
):
"""Override MOE configuration via environment variables."""
# Set environment variables using envs utility class
@@ -863,6 +876,12 @@
envs.SGLANG_KT_AMX_METHOD.set(amx_method)
if chunked_prefill_size is not None:
envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.set(
max_deferred_experts_per_token
)
envs.SGLANG_KT_MOE_TOTAL_LAYERS.set(num_hidden_layers)
cls.max_deferred_experts_per_token = max_deferred_experts_per_token
cls.total_num_hidden_layers = num_hidden_layers
class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
@@ -887,6 +906,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value
max_deferred = envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.value
total_layers = envs.SGLANG_KT_MOE_TOTAL_LAYERS.value
self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
quant_config,
@@ -896,6 +917,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
threadpool_count,
amx_weight_path,
chunked_prefill_size,
max_deferred_experts_per_token=max_deferred,
total_num_hidden_layers=total_layers,
)
self.marlin_method = CompressedTensorsWNA16MoEMethod(
quant_config, self.num_gpu_experts
......
@@ -437,6 +437,7 @@ class ServerArgs:
kt_cpuinfer: Optional[int] = None
kt_threadpool_count: Optional[int] = None
kt_num_gpu_experts: Optional[int] = None
kt_max_deferred_experts_per_token: Optional[int] = None
# Double Sparsity
enable_double_sparsity: bool = False
@@ -1329,6 +1330,21 @@ class ServerArgs:
override_config,
)
num_hidden_layers = None
if self.kt_max_deferred_experts_per_token is not None:
try:
model_config = self.get_model_config()
base_config = (
getattr(model_config, "hf_text_config", None)
or model_config.hf_config
)
num_hidden_layers = getattr(base_config, "num_hidden_layers", None)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to load model config for kt_max_deferred_experts_per_token: %s",
exc,
)
override_config(
CompressedTensorsWNA16AMXEPMoEMethod,
self.kt_num_gpu_experts,
@@ -1337,6 +1353,8 @@ class ServerArgs:
self.kt_amx_weight_path,
self.kt_amx_method,
self.chunked_prefill_size,
self.kt_max_deferred_experts_per_token,
num_hidden_layers,
)
def _handle_data_parallelism(self):
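For illustration, the layer-count lookup above factored into a standalone helper. Here model_config stands for SGLang's ModelConfig wrapper, which exposes the Hugging Face config as .hf_config and, for multimodal models, the text sub-config as .hf_text_config; the helper name is the editor's, not the patch's.

    # Editor's sketch of the fallback chain used above; not part of the patch.
    def derive_num_hidden_layers(model_config):
        # Prefer the text sub-config when present, otherwise the top-level HF config.
        base_config = getattr(model_config, "hf_text_config", None) or model_config.hf_config
        # Returns None if the config does not expose a layer count.
        return getattr(base_config, "num_hidden_layers", None)

If the lookup fails, override_config receives num_hidden_layers=None, and per the create_weights hunk above the final-layer exemption is then skipped: every layer, including the last, uses the configured cap.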
@@ -3038,6 +3056,12 @@ class ServerArgs:
type=int,
help="[ktransformers parameter] The number of GPU experts.",
)
parser.add_argument(
"--kt-max-deferred-experts-per-token",
type=int,
default=ServerArgs.kt_max_deferred_experts_per_token,
help="Maximum number of experts deferred to CPU per token. All MoE layers except the final one use this value; the final layer always uses 0.",
)
# Double Sparsity
parser.add_argument(
......
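Finally, a hedged usage sketch for the new flag. Only --kt-max-deferred-experts-per-token itself comes from this patch; the argparse wiring via ServerArgs.add_cli_args and the model path are the editor's assumptions about the surrounding SGLang plumbing.

    # Editor's usage sketch; the helper method and model path are assumed.
    import argparse
    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)  # assumed to register --kt-max-deferred-experts-per-token among the other flags
    args = parser.parse_args([
        "--model-path", "/models/some-moe-model",     # illustrative
        "--kt-max-deferred-experts-per-token", "2",   # allow up to 2 deferred experts per token
    ])
    print(args.kt_max_deferred_experts_per_token)     # -> 2 (parsed with type=int)

    # Equivalent server launch (illustrative):
    #   python -m sglang.launch_server --model-path <model> --kt-max-deferred-experts-per-token 2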