"examples/vscode:/vscode.git/clone" did not exist on "6c4dbe23eb85e5d1da00ccaf4923a275d8769a7f"
Commit 0975d9e8 authored by zhuwenwen
Browse files

[kernel] add VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND to use lightop's moe_sum fusion...

[kernel] add VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND to use lightop's moe_sum fusion operator for deepseek
parent c0707728
......@@ -166,8 +166,9 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1109,6 +1110,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_RMS_QUANT":
lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in
("true", "1")),
# If set to "true" or "1", vLLM uses lightop's fused moe_sum operator for
# DeepSeek models. Enabled by default.
"VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND":
lambda: (os.getenv('VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND', 'True').lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -40,9 +40,13 @@ from vllm.model_executor.layers.fused_moe.utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from lightop import op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
......@@ -1257,13 +1261,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def inplace_fused_experts_fake(
......@@ -1325,14 +1331,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def outplace_fused_experts_fake(
......@@ -1414,7 +1422,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
......@@ -1472,7 +1482,9 @@ def fused_experts(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def fused_experts_impl(
......@@ -1500,6 +1512,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
......@@ -1744,9 +1758,30 @@ def fused_experts_impl(
block_shape=block_shape,
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
if envs.VLLM_USE_DEEPSEEK_MOE_SUM_MUL_AND:
if envs.VLLM_USE_LIGHT_OP and not dpsk_fp16_quick:
if shared_output is not None:
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
if shared_output is not None:
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
else:
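# Avoid FP16 overflow: instead of multiplying the routed output by
# routed_scaling_factor, scale shared_output by 1 / routed_scaling_factor.
# See DeepseekV2DecoderLayer for more details.
# NOTE(review): the expression below is not assigned to anything, so its
# result appears to be discarded -- confirm whether an assignment is missing.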
out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# In theory, this case does not occur for DeepSeek.
else:
if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
......@@ -1779,6 +1814,8 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1864,7 +1901,9 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
......@@ -2081,4 +2120,4 @@ def modular_triton_fused_moe(
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
)
\ No newline at end of file
......@@ -206,6 +206,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
shared_output: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
......@@ -357,6 +358,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
shared_output: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
......@@ -385,6 +387,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x=x,
layer=layer,
router_logits=router_logits,
shared_output=shared_output,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
......@@ -408,6 +411,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
shared_output: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
......@@ -460,7 +464,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cpu(
......@@ -1427,13 +1433,14 @@ class FusedMoE(torch.nn.Module):
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
router_logits: torch.Tensor,
shared_output: torch.Tensor):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
return torch.ops.vllm.moe_forward(hidden_states, router_logits, shared_output,
self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
......@@ -1513,7 +1520,8 @@ class FusedMoE(torch.nn.Module):
return full_final_hidden_states
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
router_logits: torch.Tensor,
shared_output: torch.Tensor):
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
......@@ -1531,6 +1539,7 @@ class FusedMoE(torch.nn.Module):
layer=self,
x=hidden_states,
router_logits=router_logits,
shared_output=shared_output,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
......@@ -1619,16 +1628,16 @@ class FusedMoE(torch.nn.Module):
return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor,
layer_name: str) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_output: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1640,4 +1649,4 @@ direct_register_custom_op(
fake_impl=moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
)
\ No newline at end of file
......@@ -209,24 +209,10 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
......@@ -1125,4 +1111,4 @@ def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
return layer_idx + i
return None
return None
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment