Unverified commit 924ca7c9, authored by Xiaoyu Zhang, committed by GitHub

Add DeepSeek V3/R1 shared experts fusion (#4918)

parent 6ff9c6a5
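
This commit adds an optional "shared experts fusion" path for DeepSeek V3/R1. The shared expert is replicated n_share_experts_fusion times (default: the tensor-parallel size) and appended to the routed expert pool, top_k is increased by one, and the extra top-k slot is routed to a randomly chosen replica, so the shared expert runs inside the fused MoE kernel instead of as a separate MLP. New server flags --n-share-experts-fusion and --disable-shared-experts-fusion control the behavior, the fused-MoE tuning benchmark gains a matching --n-share-experts-fusion option, and bench_serving gains a configurable --warmup-requests count. The sketch below re-expresses the routing trick added to grouped_topk in isolation; the helper name is mine and the code is illustrative, not the committed implementation.

# --- illustrative sketch, not part of the commit ---
import torch

def fuse_shared_expert_slot(
    topk_ids: torch.Tensor,        # [num_tokens, top_k], integer expert ids
    topk_weights: torch.Tensor,    # [num_tokens, top_k], float weights
    num_routed_experts: int,
    n_share_experts_fusion: int,
    routed_scaling_factor: float = 2.5,
) -> None:
    # Point the last top-k slot at a random shared-expert replica (indices
    # >= num_routed_experts) and weight it by sum(routed weights) / scale,
    # mirroring the grouped_topk change in this commit.
    topk_ids[:, -1] = torch.randint(
        low=num_routed_experts,
        high=num_routed_experts + n_share_experts_fusion,
        size=(topk_ids.size(0),),
        dtype=topk_ids.dtype,
        device=topk_ids.device,
    )
    topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
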
...@@ -399,7 +399,12 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
n_share_fusion_experts = args.n_share_experts_fusion
E = (
config.n_routed_experts + n_share_fusion_experts
if config.architectures[0] in ["DeepseekV3ForCausalLM"]
else config.n_routed_experts
)
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
...@@ -559,6 +564,12 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument(
"--n-share-experts-fusion",
type=int,
default=0,
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
)
args = parser.parse_args()
main(args)
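
For context, DeepSeek V3/R1 has 256 routed experts, so the expert count E used for tuning grows by the number of fused shared-expert replicas. A minimal illustration of the arithmetic above (the values are examples, not defaults):

# --- illustrative sketch, not part of the commit ---
n_routed_experts = 256         # DeepSeek V3/R1
n_share_experts_fusion = 8     # e.g. one replica per rank at TP=8
E = n_routed_experts + n_share_experts_fusion  # 264 experts to tune for
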
...@@ -993,13 +993,16 @@ async def benchmark(
return await request_func(request_func_input=request_func_input, pbar=pbar)
# Warmup
print(f"Starting warmup with {args.warmup_requests} sequences...")
# Use the first request for all warmup iterations
test_prompt, test_prompt_len, test_output_len = input_requests[0]
if lora_names != None and len(lora_names) != 0:
lora_name = lora_names[0]
else:
lora_name = None
# Create the test input once
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
...@@ -1009,14 +1012,26 @@ async def benchmark(
lora_name=lora_name,
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
# Run warmup requests
warmup_tasks = []
for _ in range(args.warmup_requests):
warmup_tasks.append(
asyncio.create_task(request_func(request_func_input=test_input))
)
warmup_outputs = await asyncio.gather(*warmup_tasks)
# Check if at least one warmup request succeeded
if not any(output.success for output in warmup_outputs):
raise ValueError(
"Warmup failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {warmup_outputs[0].error}"
)
else:
print(
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
)
# Flush cache
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
...@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
# Set default value for warmup_requests if not present
if not hasattr(args, "warmup_requests"):
args.warmup_requests = 1
print(f"benchmark_args={args}") print(f"benchmark_args={args}")
# Set global environments # Set global environments
...@@ -1560,6 +1579,12 @@ if __name__ == "__main__": ...@@ -1560,6 +1579,12 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Flush the cache before running the benchmark", help="Flush the cache before running the benchmark",
) )
parser.add_argument(
"--warmup-requests",
type=int,
default=1,
help="Number of warmup requests to run before the benchmark",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
......
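
Note on the warmup change above: instead of a single initial test request, the benchmark now issues --warmup-requests copies of the first request concurrently and aborts only if every one of them fails; run_benchmark also backfills warmup_requests = 1 for callers that pass a bare argparse.Namespace. A self-contained sketch of the same asyncio.gather pattern (the coroutine names here are illustrative, not the script's API):

# --- illustrative sketch, not part of the commit ---
import asyncio

async def send_request(i: int) -> bool:
    await asyncio.sleep(0.01)  # stand-in for one benchmark request
    return True

async def warmup(n: int) -> None:
    # Fire n warmup requests concurrently; require at least one success.
    results = await asyncio.gather(*(send_request(i) for i in range(n)))
    if not any(results):
        raise ValueError("Warmup failed")

asyncio.run(warmup(4))
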
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
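
The two JSON blocks above are fused-MoE Triton kernel tuning tables (the file names are not shown in this extract). Each top-level key is a token batch size and maps to the tile sizes (BLOCK_SIZE_M/N/K), GROUP_SIZE_M, warp count, and pipeline stage count to use at that size; presumably they were generated for the expert shapes introduced by shared-experts fusion. A hedged sketch of how such a table is typically consumed, with a hypothetical file name (this is not sglang's actual lookup code):

# --- illustrative sketch, not part of the commit ---
import json

def pick_config(table: dict, num_tokens: int) -> dict:
    # Choose the entry whose batch-size key is closest to num_tokens.
    return table[min(table, key=lambda k: abs(int(k) - num_tokens))]

with open("fused_moe_tuning_table.json") as f:  # hypothetical path
    table = json.load(f)
print(pick_config(table, num_tokens=300))  # selects the "256" entry
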
...@@ -13,11 +13,6 @@ import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
...@@ -42,9 +37,6 @@ if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
else:
from vllm import _custom_ops as vllm_ops
...@@ -764,6 +756,16 @@ def invoke_fused_moe_kernel(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
from sglang.srt.layers.quantization.int8_kernel import (
per_token_group_quant_int8,
per_token_quant_int8,
)
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......
...@@ -12,12 +12,14 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
_is_cuda = is_cuda()
...@@ -102,11 +104,13 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
num_experts = scores.shape[1]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
...@@ -122,9 +126,25 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
...@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = gating_output.sigmoid()
num_token = scores.shape[0]
num_experts = scores.shape[1]
scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(num_token, num_expert_group, -1)
...@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
if n_share_experts_fusion:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
)
topk_weights[:, -1] = (
topk_weights[:, :-1].sum(dim=-1) / 2.5
) # 2.5 is the routed_scaling_factor.
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
...@@ -179,6 +218,7 @@ def biased_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
n_share_experts_fusion: int = 0,
):
biased_grouped_topk_fn = (
torch.compile(
...@@ -195,6 +235,7 @@ def biased_grouped_topk(
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
...@@ -210,7 +251,10 @@ def select_experts(
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
n_share_experts_fusion = 0
if global_server_args_dict["n_share_experts_fusion"] is not None:
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# DeepSeek V2/V3/R1 series models use grouped_topk
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
...@@ -222,6 +266,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
...@@ -232,6 +277,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
)
elif torch_native and custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(
......
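
A note on the renormalization change in grouped_topk and biased_grouped_topk_impl above: with S defined as the sum of the k-1 routed weights, the fused slot is set to S / 2.5, and when fusion is active the renormalization divides by S (the routed weights only) rather than by the full sum. The routed weights therefore still sum to 1 while the fused shared-expert slot lands at a constant 1 / 2.5 = 0.4. Assuming the combined MoE output is then multiplied by routed_scaling_factor = 2.5 downstream, as in the unfused DeepSeek V3 path, the shared expert receives an effective weight of 2.5 * 0.4 = 1, i.e. the same contribution it would have made as a separate shared-experts MLP.
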
...@@ -51,7 +51,6 @@ except ImportError:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
...@@ -203,6 +202,8 @@ def get_linear_quant_method(
def gptq_get_quant_method(self, layer, prefix):
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
......
...@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
...@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self)
return None
......
...@@ -4,18 +4,19 @@
import enum
import logging
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
all_close_1d,
...@@ -55,7 +56,13 @@ __all__ = [
]
class CompressedTensorsMoEMethod:
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if cls is CompressedTensorsMoEMethod:
return super().__new__(cls)
return super().__new__(cls)
@staticmethod
def get_moe_method(
...@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
...@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
params_dtype = torch.float8_e4m3fn
...@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
...@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
...@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
):
from sglang.srt.layers.moe.fused_moe_triton import (
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
...@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
if not VLLM_AVAILABLE: if not VLLM_AVAILABLE:
raise ImportError( raise ImportError(
......
...@@ -81,6 +81,8 @@ global_server_args_dict = {
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
}
logger = logging.getLogger(__name__)
......
...@@ -157,6 +157,8 @@ class ModelRunner:
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
}
)
......
...@@ -16,12 +16,14 @@
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
import logging
import os
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
...@@ -87,6 +89,8 @@ if _is_hip:
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
class DeepseekV2MLP(nn.Module):
def __init__(
...@@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.n_share_experts_fusion = (
global_server_args_dict["n_share_experts_fusion"]
if global_server_args_dict["n_share_experts_fusion"] is not None
else 0
)
self.routed_scaling_factor = config.routed_scaling_factor
if self.tp_size > config.n_routed_experts:
raise ValueError(
...@@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts + self.n_share_experts_fusion,
top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
...@@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
),
)
if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
# disable tp for shared experts when enable deepep moe
if not global_server_args_dict["enable_deepep_moe"]:
...@@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
return self.forward_deepep(hidden_states, forward_mode)
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = (
...@@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
if (
global_server_args_dict.get("disable_shared_experts_fusion", False)
or self.config.architectures[0] != "DeepseekV3ForCausalLM"
or self.config.n_routed_experts != 256
or self.config.routed_scaling_factor != 2.5
):
self.n_share_experts_fusion = None
global_server_args_dict["n_share_experts_fusion"] = None
logger.info(
"Shared experts fusion optimization is only supported for DeepSeek V3/R1; disabling it."
)
elif self.n_share_experts_fusion is None:
global_server_args_dict["n_share_experts_fusion"] = self.tp_size
self.n_share_experts_fusion = self.tp_size
logger.info(
f"Shared experts fusion optimization is enabled by default for DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. Tune it with --n-share-experts-fusion or disable it with --disable-shared-experts-fusion."
)
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
...@@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
weights_list = list(weights)
weights_dict = dict(weights_list)
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale_inv",
"gate_proj.weight",
"gate_proj.weight_scale_inv",
"up_proj.weight",
"up_proj.weight_scale_inv",
]
names_to_remove = []
for moe_layer in tqdm(
range(
self.config.first_k_dense_replace,
self.config.num_hidden_layers,
self.config.moe_layer_freq,
),
desc=f"Cloning {self.n_share_experts_fusion} "
"replicas of the shared expert into MoE",
):
for num_repeat in range(self.n_share_experts_fusion):
for suffix in suffix_list:
shared_expert_weight_name = (
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
)
weights_list.append(
(
f"model.layers.{moe_layer}."
f"mlp.experts."
f"{self.config.n_routed_experts + num_repeat}"
f".{suffix}",
weights_dict[shared_expert_weight_name].clone(),
)
)
names_to_remove += [shared_expert_weight_name]
weights = [w for w in weights_list if w[0] not in names_to_remove]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
...@@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts
+ (
self.n_share_experts_fusion
if self.n_share_experts_fusion is not None
else 0
),
)
params_dict = dict(self.named_parameters())
......
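
Two details of the deepseek_v2.py changes above are easy to miss. First, DeepseekV2ForCausalLM.__init__ enables the fusion only for DeepseekV3ForCausalLM checkpoints with 256 routed experts and routed_scaling_factor == 2.5 (and not on HIP, per server_args.py below); otherwise it resets n_share_experts_fusion to None in global_server_args_dict, and when the flag is left unset it defaults to the TP size. Second, load_weights materializes the replicas by cloning the shared expert's gate_proj/up_proj/down_proj weights (and their weight_scale_inv tensors) into expert slots n_routed_experts through n_routed_experts + n_share_experts_fusion - 1 for every MoE layer, so each replica adds a full copy of the shared expert's parameters on top of the routed experts, and the num_experts passed to make_expert_params_mapping is raised accordingly.
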
...@@ -183,6 +183,8 @@ class ServerArgs:
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
n_share_experts_fusion: Optional[int] = None
disable_shared_experts_fusion: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
...@@ -224,6 +226,9 @@ class ServerArgs:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
if is_hip():
self.disable_shared_experts_fusion = True
# Set mem fraction static, which depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
...@@ -1102,6 +1107,19 @@ class ServerArgs:
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
parser.add_argument(
"--n-share-experts-fusion",
type=int,
default=None,
help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
"we use tp_size by default.",
)
parser.add_argument(
"--disable-shared-experts-fusion",
action="store_true",
help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
)
# Server warmups
parser.add_argument(
"--warmups",
......
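
For reference, a hedged sketch of enabling the feature programmatically. The two fusion fields come from this diff; model_path and tp_size are assumed to be the existing ServerArgs fields of those names:

# --- illustrative sketch, not part of the commit ---
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",   # assumed checkpoint id
    tp_size=8,
    n_share_experts_fusion=8,               # omit to default to tp_size on V3/R1
    # disable_shared_experts_fusion=True,   # opt out of the fusion entirely
)
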
...@@ -144,7 +144,7 @@ def moe_align_block_size_triton(
[32, 64, 128, 256], # block_size
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
[1, 2, 4, 8, 16, 32, 64], # topk
[64, 160, 256, 257, 260, 264], # num_experts
)
),
)
......
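
The extra num_experts values in the test matrix above (257, 260, 264) correspond to the fused layouts: 256 routed experts plus 1, 4, or 8 shared-expert replicas, presumably matching --n-share-experts-fusion values for common TP sizes.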