Unverified commit 627bac64, authored by Atream, committed by GitHub
Browse files

Support Expert Deferral Mechanism in KTransformers (#12586)


Co-authored-by: Chen Hongtao <56470055+chenht2022@users.noreply.github.com>
Co-authored-by: chenht2022 <cht22@mails.tsinghua.edu.cn>
parent 3651cfbf
...@@ -270,6 +270,8 @@ class Envs: ...@@ -270,6 +270,8 @@ class Envs:
SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None) SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None)
SGLANG_KT_AMX_METHOD = EnvStr(None) SGLANG_KT_AMX_METHOD = EnvStr(None)
SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None) SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)
SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN = EnvInt(None)
SGLANG_KT_MOE_TOTAL_LAYERS = EnvInt(None)
# Sparse Embeddings # Sparse Embeddings
SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None) SGLANG_EMBEDDINGS_SPARSE_HEAD = EnvStr(None)
......
...@@ -705,8 +705,9 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod): ...@@ -705,8 +705,9 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
threadpool_count, threadpool_count,
amx_weight_path, amx_weight_path,
chunked_prefill_size, chunked_prefill_size,
max_deferred_experts_per_token,
total_num_hidden_layers,
): ):
if not KTRANSFORMERS_AVAILABLE: if not KTRANSFORMERS_AVAILABLE:
raise ImportError( raise ImportError(
"kt_kernel is not installed, to use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel." "kt_kernel is not installed, to use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
...@@ -723,6 +724,8 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod): ...@@ -723,6 +724,8 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
self.cpuinfer = cpuinfer self.cpuinfer = cpuinfer
self.threadpool_count = threadpool_count self.threadpool_count = threadpool_count
self.amx_wrapper = None self.amx_wrapper = None
self.max_deferred_experts_per_token = max_deferred_experts_per_token
self.total_num_hidden_layers = total_num_hidden_layers
def create_weights( def create_weights(
self, self,
...@@ -733,6 +736,13 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod): ...@@ -733,6 +736,13 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
layer_max_deferred = self.max_deferred_experts_per_token or 0
if (
self.max_deferred_experts_per_token is not None
and self.total_num_hidden_layers is not None
and self.layer_idx == self.total_num_hidden_layers - 1
):
layer_max_deferred = 0
self.experts_num = num_experts self.experts_num = num_experts
self.num_experts_per_tok = extra_weight_attrs.pop("top_k") self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -751,6 +761,7 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod): ...@@ -751,6 +761,7 @@ class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
threadpool_count=self.threadpool_count, threadpool_count=self.threadpool_count,
amx_weight_path=self.amx_weight_path, amx_weight_path=self.amx_weight_path,
chunked_prefill_size=self.chunked_prefill_size, chunked_prefill_size=self.chunked_prefill_size,
max_deferred_experts_per_token=layer_max_deferred,
amx_method=envs.SGLANG_KT_AMX_METHOD.value, amx_method=envs.SGLANG_KT_AMX_METHOD.value,
) )
...@@ -848,6 +859,8 @@ def override_config( ...@@ -848,6 +859,8 @@ def override_config(
amx_weight_path, amx_weight_path,
amx_method, amx_method,
chunked_prefill_size, chunked_prefill_size,
max_deferred_experts_per_token,
num_hidden_layers,
): ):
"""Override MOE configuration via environment variables.""" """Override MOE configuration via environment variables."""
# Set environment variables using envs utility class # Set environment variables using envs utility class
...@@ -863,6 +876,12 @@ def override_config( ...@@ -863,6 +876,12 @@ def override_config(
envs.SGLANG_KT_AMX_METHOD.set(amx_method) envs.SGLANG_KT_AMX_METHOD.set(amx_method)
if chunked_prefill_size is not None: if chunked_prefill_size is not None:
envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size) envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)
envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.set(
max_deferred_experts_per_token
)
envs.SGLANG_KT_MOE_TOTAL_LAYERS.set(num_hidden_layers)
cls.max_deferred_experts_per_token = max_deferred_experts_per_token
cls.total_num_hidden_layers = num_hidden_layers
class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
...@@ -887,6 +906,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod): ...@@ -887,6 +906,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value
max_deferred = envs.SGLANG_KT_MOE_MAX_DEFERRED_EXPERTS_PER_TOKEN.value
total_layers = envs.SGLANG_KT_MOE_TOTAL_LAYERS.value
self.AMX_method = CompressedTensorsWNA16AMXMoEMethod( self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
quant_config, quant_config,
...@@ -896,6 +917,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod): ...@@ -896,6 +917,8 @@ class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):
threadpool_count, threadpool_count,
amx_weight_path, amx_weight_path,
chunked_prefill_size, chunked_prefill_size,
max_deferred_experts_per_token=max_deferred,
total_num_hidden_layers=total_layers,
) )
self.marlin_method = CompressedTensorsWNA16MoEMethod( self.marlin_method = CompressedTensorsWNA16MoEMethod(
quant_config, self.num_gpu_experts quant_config, self.num_gpu_experts
......
...@@ -437,6 +437,7 @@ class ServerArgs: ...@@ -437,6 +437,7 @@ class ServerArgs:
kt_cpuinfer: Optional[int] = None kt_cpuinfer: Optional[int] = None
kt_threadpool_count: Optional[int] = None kt_threadpool_count: Optional[int] = None
kt_num_gpu_experts: Optional[int] = None kt_num_gpu_experts: Optional[int] = None
kt_max_deferred_experts_per_token: Optional[int] = None
# Double Sparsity # Double Sparsity
enable_double_sparsity: bool = False enable_double_sparsity: bool = False
...@@ -1329,6 +1330,21 @@ class ServerArgs: ...@@ -1329,6 +1330,21 @@ class ServerArgs:
override_config, override_config,
) )
num_hidden_layers = None
if self.kt_max_deferred_experts_per_token is not None:
try:
model_config = self.get_model_config()
base_config = (
getattr(model_config, "hf_text_config", None)
or model_config.hf_config
)
num_hidden_layers = getattr(base_config, "num_hidden_layers", None)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to load model config for kt_max_deferred_experts_per_token: %s",
exc,
)
override_config( override_config(
CompressedTensorsWNA16AMXEPMoEMethod, CompressedTensorsWNA16AMXEPMoEMethod,
self.kt_num_gpu_experts, self.kt_num_gpu_experts,
...@@ -1337,6 +1353,8 @@ class ServerArgs: ...@@ -1337,6 +1353,8 @@ class ServerArgs:
self.kt_amx_weight_path, self.kt_amx_weight_path,
self.kt_amx_method, self.kt_amx_method,
self.chunked_prefill_size, self.chunked_prefill_size,
self.kt_max_deferred_experts_per_token,
num_hidden_layers,
) )
def _handle_data_parallelism(self): def _handle_data_parallelism(self):
...@@ -3038,6 +3056,12 @@ class ServerArgs: ...@@ -3038,6 +3056,12 @@ class ServerArgs:
type=int, type=int,
help="[ktransformers parameter] The number of GPU experts.", help="[ktransformers parameter] The number of GPU experts.",
) )
parser.add_argument(
"--kt-max-deferred-experts-per-token",
type=int,
default=ServerArgs.kt_max_deferred_experts_per_token,
help="Maximum number of experts deferred to CPU per token. All MoE layers except the final one use this value; the final layer always uses 0.",
)
# Double Sparsity # Double Sparsity
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment