Unverified Commit 924ca7c9 authored by Xiaoyu Zhang, committed by GitHub

Add DeepSeek V3/R1 shared experts fusion (#4918)

parent 6ff9c6a5
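At a high level, this change replicates the DeepSeek V3/R1 shared expert and appends the replicas to the routed-expert table, so the shared expert is executed inside the fused MoE kernel instead of as a separate dense MLP. A rough sketch of the resulting sizes (illustrative numbers only: 256 routed experts and routed_scaling_factor 2.5 are checked later in this diff, num_experts_per_tok is 8 in the published V3/R1 configs, and the fusion count defaults to tp_size):

```python
# Illustrative arithmetic only; it mirrors the expert/top-k bookkeeping added below.
n_routed_experts = 256        # DeepSeek V3/R1 routed experts (checked in the model code)
num_experts_per_tok = 8       # routed top-k per token in the published V3/R1 configs
tp_size = 8                   # example TP degree; the fusion count defaults to tp_size

n_share_experts_fusion = tp_size                              # one replica per TP rank
num_experts = n_routed_experts + n_share_experts_fusion       # 264 experts in FusedMoE
top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)  # 9: one slot for a replica

print(num_experts, top_k)  # 264 9
```

The extended num_experts test values near the bottom of the diff (257, 260, 264) cover exactly this just-above-256 range.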
......@@ -399,7 +399,12 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
E = config.n_routed_experts
n_share_fusion_experts = args.n_share_experts_fusion
E = (
config.n_routed_experts + n_share_fusion_experts
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
......@@ -559,6 +564,12 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument(
"--n-share-experts-fusion",
type=int,
default=0,
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
)
args = parser.parse_args()
main(args)
......@@ -993,13 +993,16 @@ async def benchmark(
return await request_func(request_func_input=request_func_input, pbar=pbar)
# Warmup
print("Starting initial single prompt test run...")
print(f"Starting warmup with {args.warmup_requests} sequences...")
# Use the first request for all warmup iterations
test_prompt, test_prompt_len, test_output_len = input_requests[0]
if lora_names != None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
lora_name = None
# Create the test input once
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
......@@ -1009,14 +1012,26 @@ async def benchmark(
lora_name=lora_name,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
# Run warmup requests
warmup_tasks = []
for _ in range(args.warmup_requests):
warmup_tasks.append(
asyncio.create_task(request_func(request_func_input=test_input))
)
warmup_outputs = await asyncio.gather(*warmup_tasks)
# Check if at least one warmup request succeeded
if not any(output.success for output in warmup_outputs):
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
"Warmup failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {warmup_outputs[0].error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
print(
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
)
# Flush cache
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
......@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
# Set default value for warmup_requests if not present
if not hasattr(args, "warmup_requests"):
args.warmup_requests = 1
print(f"benchmark_args={args}")
# Set global environments
......@@ -1560,6 +1579,12 @@ if __name__ == "__main__":
action="store_true",
help="Flush the cache before running the benchmark",
)
parser.add_argument(
"--warmup-requests",
type=int,
default=1,
help="Number of warmup requests to run before the benchmark",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
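The two JSON blocks above are Triton fused-MoE tuning tables: each top-level key is a token-batch size and each value is the tile and pipeline configuration (BLOCK_SIZE_*, GROUP_SIZE_M, num_warps, num_stages) tuned for that size, presumably regenerated here for the enlarged expert count. A minimal sketch of how such a table is typically consumed, using a hypothetical pick_config helper rather than the repository's actual loader:

```python
# Hypothetical helper; the real loader keys these tables by expert count, shard
# size, dtype, and device in the config file name and picks the nearest entry.
def pick_config(table: dict, num_tokens: int) -> dict:
    """Return the tuned kernel config whose batch-size key is nearest num_tokens."""
    best = min((int(k) for k in table), key=lambda k: abs(k - num_tokens))
    return table[str(best)]

table = {
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "num_warps": 4, "num_stages": 3},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "num_warps": 4, "num_stages": 4},
}
print(pick_config(table, 100))  # falls back to the "16" entry, the nearest tuned size
```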
......@@ -13,11 +13,6 @@ import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
......@@ -42,9 +37,6 @@ if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
else:
from vllm import _custom_ops as vllm_ops
......@@ -764,6 +756,16 @@ def invoke_fused_moe_kernel(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......
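This hunk, like the compressed_tensors changes further down, moves imports from module scope into the function that needs them, the usual way to break an import cycle or to avoid loading GPU-only kernels at import time. A minimal sketch of the pattern with placeholder module names:

```python
# Placeholder module names; illustrates the deferred-import pattern only.
def run_kernel(x):
    # Imported lazily: heavy_gpu_module may import this module back, or may only
    # be importable on a GPU build, so the import is deferred to first use.
    from heavy_gpu_module import heavy_gpu_op
    return heavy_gpu_op(x)
```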
......@@ -12,12 +12,14 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
......@@ -102,11 +104,13 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
......@@ -122,9 +126,25 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
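The new branch in grouped_topk (mirrored in biased_grouped_topk_impl below) is the routing half of the fusion: the last top-k slot, which exists because top_k was bumped by one in the model code, is overwritten with a randomly chosen shared-expert replica id in [num_experts, num_experts + n_share_experts_fusion), its weight is set to the routed weights' sum divided by the routed_scaling_factor (2.5 for V3/R1), and renormalization divides by the sum of the routed slots only. A toy-tensor sketch of the same manipulation, not taken from the diff:

```python
import torch

num_experts, n_share_experts_fusion, routed_scaling_factor = 256, 8, 2.5
topk_weights = torch.tensor([[0.4, 0.3, 0.2, 0.1]])            # toy weights, top_k = 4
topk_ids = torch.tensor([[17, 203, 5, 88]], dtype=torch.int64)

# Overwrite the last slot with a random shared-expert replica id (256..263 here).
topk_ids[:, -1] = torch.randint(
    num_experts, num_experts + n_share_experts_fusion, (topk_ids.size(0),)
)
# Its weight approximates the shared expert's contribution relative to the routed ones.
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

# Renormalize over the routed slots only: they now sum to 1 while the fused slot
# keeps a fixed 1 / routed_scaling_factor = 0.4 weight.
topk_weights = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)
print(topk_ids, topk_weights)
```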
......@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = gating_output.sigmoid()
num_token = scores.shape[0]
num_experts = scores.shape[1]
scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(num_token, num_expert_group, -1)
......@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
......@@ -179,6 +218,7 @@ def biased_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
n_share_experts_fusion: int = 0,
):
biased_grouped_topk_fn = (
torch.compile(
......@@ -195,6 +235,7 @@ def biased_grouped_topk(
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
......@@ -210,7 +251,10 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
# DeekSeekv2 uses grouped_top_k
n_share_experts_fusion = 0
if global_server_args_dict["n_share_experts_fusion"] is not None:
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# DeepSeek V2/V3/R1 series models use grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
......@@ -222,6 +266,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
......@@ -232,6 +277,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
......
......@@ -51,7 +51,6 @@ except ImportError:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
......@@ -203,6 +202,8 @@ def get_linear_quant_method(
def gptq_get_quant_method(self, layer, prefix):
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
......
......@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self)
return None
......
......@@ -4,18 +4,19 @@
import enum
import logging
from enum import Enum
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.moe.topk import select_experts
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
......@@ -55,7 +56,13 @@ __all__ = [
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
return super().__new__(cls)
@staticmethod
def get_moe_method(
......@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
......@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
params_dtype = torch.float8_e4m3fn
......@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
......@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
......@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
......@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE:
raise ImportError(
......
......@@ -81,6 +81,8 @@ global_server_args_dict = {
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
}
logger = logging.getLogger(__name__)
......
......@@ -157,6 +157,8 @@ class ModelRunner:
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
}
)
......
......@@ -16,12 +16,14 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
......@@ -87,6 +89,8 @@ if _is_hip:
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = (
global_server_args_dict["n_share_experts_fusion"]
if global_server_args_dict["n_share_experts_fusion"] is not None
else 0
)
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
......@@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
num_experts=config.n_routed_experts + self.n_share_experts_fusion,
top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
......@@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
),
)
if config.n_shared_experts is not None:
if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
if not global_server_args_dict["enable_deepep_moe"]:
......@@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None:
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = (
......@@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# Only DeepSeek V3/R1 can use the shared experts fusion optimization for now.
if (
global_server_args_dict.get("disable_shared_experts_fusion", False)
or self.config.architectures[0] != "DeepseekV3ForCausalLM"
or self.config.n_routed_experts != 256
or self.config.routed_scaling_factor != 2.5
):
self.n_share_experts_fusion = None
global_server_args_dict["n_share_experts_fusion"] = None
logger.info(
"Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
)
elif self.n_share_experts_fusion is None:
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
self.n_share_experts_fusion = self.tp_size
logger.info(
f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
)
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
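The constructor now decides up front whether fusion is active: it is forced off when --disable-shared-experts-fusion is set (including on HIP, per the server_args hunk below), when the architecture is not DeepseekV3ForCausalLM, or when the checkpoint lacks the expected 256 routed experts and routed_scaling_factor of 2.5; otherwise an unset --n-share-experts-fusion defaults to tp_size. A condensed sketch of that decision, using a hypothetical helper name:

```python
from typing import Optional

# Hypothetical helper; condenses the checks added to DeepseekV2ForCausalLM.__init__.
def resolve_n_share_experts_fusion(
    requested: Optional[int],
    disable: bool,
    architecture: str,
    n_routed_experts: int,
    routed_scaling_factor: float,
    tp_size: int,
) -> Optional[int]:
    if (
        disable
        or architecture != "DeepseekV3ForCausalLM"
        or n_routed_experts != 256
        or routed_scaling_factor != 2.5
    ):
        return None  # fusion unsupported for this model; fall back to the old path
    return requested if requested is not None else tp_size  # default: one replica per TP rank
```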
......@@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
weights_list = list(weights)
weights_dict = dict(weights_list)
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale_inv",
"gate_proj.weight",
"gate_proj.weight_scale_inv",
"up_proj.weight",
"up_proj.weight_scale_inv",
]
names_to_remove = []
for moe_layer in tqdm(
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
),
desc=f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE",
):
for num_repeat in range(self.n_share_experts_fusion):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name].clone(),
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
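The loader change above materializes the fusion in the weight stream: for every MoE layer it clones the shared expert's tensors once per replica, re-registers them under routed-expert names starting at index n_routed_experts, and drops the original shared-expert entries. For one layer and one replica the rename looks like this (illustrative values only):

```python
# Illustration of the rename performed by the cloning loop above.
moe_layer, num_repeat, n_routed_experts = 3, 0, 256
suffix = "gate_proj.weight"

src = f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
dst = f"model.layers.{moe_layer}.mlp.experts.{n_routed_experts + num_repeat}.{suffix}"
print(src)  # model.layers.3.mlp.shared_experts.gate_proj.weight
print(dst)  # model.layers.3.mlp.experts.256.gate_proj.weight
```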
......@@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.n_share_experts_fusion
if self.n_share_experts_fusion is not None
else 0
),
)
params_dict = dict(self.named_parameters())
......
......@@ -183,6 +183,8 @@ class ServerArgs:
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
n_share_experts_fusion: Optional[int] = None
disable_shared_experts_fusion: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
......@@ -224,6 +226,9 @@ class ServerArgs:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
if is_hip():
self.disable_shared_experts_fusion = True
# Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
......@@ -1102,6 +1107,19 @@ class ServerArgs:
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
parser.add_argument(
"--n-share-experts-fusion",
type=int,
default=None,
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
"we use tp_size by default.",
)
parser.add_argument(
"--disable-shared-experts-fusion",
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
# Server warmups
parser.add_argument(
"--warmups",
......
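Both new server flags are plain argparse options, so their behavior can be checked in isolation; a minimal, self-contained sketch (standalone parser rather than the full ServerArgs plumbing):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--n-share-experts-fusion", type=int, default=None)
parser.add_argument("--disable-shared-experts-fusion", action="store_true")

# e.g. fuse 16 shared-expert replicas (one per TP rank on a 16-GPU deployment)
args = parser.parse_args(["--n-share-experts-fusion", "16"])
print(args.n_share_experts_fusion, args.disable_shared_experts_fusion)  # 16 False
```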
......@@ -144,7 +144,7 @@ def moe_align_block_size_triton(
[32, 64, 128, 256], # block_size
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
[1, 2, 4, 8, 16, 32, 64], # topk
[64, 160, 256], # num_experts
[64, 160, 256, 257, 260, 264], # num_experts
)
),
)
......