Commit 0975d9e8 authored by zhuwenwen
Browse files

[kernel] add VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND to use lightop's moe_sum fusion...

[kernel] add VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND to use lightop's moe_sum fusion operator for deepseek
parent c0707728
...@@ -166,8 +166,9 @@ if TYPE_CHECKING: ...@@ -166,8 +166,9 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1109,6 +1110,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1109,6 +1110,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")), ("true", "1")),
# vllm will use lightop's moe_sum fusion operator for deepseek
"VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND":
lambda: (os.getenv('VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND', 'True').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -40,9 +40,13 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -40,9 +40,13 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
...@@ -1257,13 +1261,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1257,13 +1261,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1325,14 +1331,16 @@ def outplace_fused_experts( ...@@ -1325,14 +1331,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant, use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1414,7 +1422,9 @@ def fused_experts( ...@@ -1414,7 +1422,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1472,7 +1482,9 @@ def fused_experts( ...@@ -1472,7 +1482,9 @@ def fused_experts(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def fused_experts_impl( def fused_experts_impl(
...@@ -1500,6 +1512,8 @@ def fused_experts_impl( ...@@ -1500,6 +1512,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1744,9 +1758,30 @@ def fused_experts_impl( ...@@ -1744,9 +1758,30 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), if envs.VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND:
out_hidden_states[begin_chunk_idx:end_chunk_idx]) if envs.VLLM_USE_LIGHT_OP and not dpsk_fp16_quick:
if shared_output is not None:
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
if shared_output is not None:
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# Deepseek theoretically wouldn't happen
else:
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1779,6 +1814,8 @@ def fused_moe( ...@@ -1779,6 +1814,8 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1864,7 +1901,9 @@ def fused_moe( ...@@ -1864,7 +1901,9 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
...@@ -2081,4 +2120,4 @@ def modular_triton_fused_moe( ...@@ -2081,4 +2120,4 @@ def modular_triton_fused_moe(
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
), ),
) )
\ No newline at end of file
...@@ -206,6 +206,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -206,6 +206,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
...@@ -357,6 +358,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -357,6 +358,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
...@@ -385,6 +387,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -385,6 +387,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x=x, x=x,
layer=layer, layer=layer,
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output,
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -408,6 +411,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -408,6 +411,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk: bool, use_grouped_topk: bool,
top_k: int, top_k: int,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: torch.Tensor,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
...@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
def forward_cpu( def forward_cpu(
...@@ -1427,13 +1433,14 @@ class FusedMoE(torch.nn.Module): ...@@ -1427,13 +1433,14 @@ class FusedMoE(torch.nn.Module):
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: torch.Tensor):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits, shared_output,
self.layer_name) self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
...@@ -1513,7 +1520,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1513,7 +1520,8 @@ class FusedMoE(torch.nn.Module):
return full_final_hidden_states return full_final_hidden_states
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: torch.Tensor):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
...@@ -1531,6 +1539,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1531,6 +1539,7 @@ class FusedMoE(torch.nn.Module):
layer=self, layer=self,
x=hidden_states, x=hidden_states,
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output,
top_k=self.top_k, top_k=self.top_k,
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
...@@ -1619,16 +1628,16 @@ class FusedMoE(torch.nn.Module): ...@@ -1619,16 +1628,16 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1640,4 +1649,4 @@ direct_register_custom_op( ...@@ -1640,4 +1649,4 @@ direct_register_custom_op(
fake_impl=moe_forward_fake, fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
\ No newline at end of file
...@@ -209,24 +209,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -209,24 +209,10 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: final_hidden_states = self.experts(
final_hidden_states = self.experts( hidden_states=hidden_states,
hidden_states=hidden_states, router_logits=router_logits,
router_logits=router_logits) * self.routed_scaling_factor shared_output=shared_output)
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
...@@ -1125,4 +1111,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, ...@@ -1125,4 +1111,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
for i in range(config.num_nextn_predict_layers): for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."): if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i return layer_idx + i
return None return None
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment