Commit 46dd30e7 authored by zhuwenwen's avatar zhuwenwen
Browse files

set USE_FUSED_RMS_QUANT=0 and USE_FUSED_SILU_MUL_QUANT=0

add VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD
parent 86d92eb9
...@@ -167,6 +167,7 @@ if TYPE_CHECKING: ...@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
...@@ -1111,6 +1112,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1111,6 +1112,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_MOE_SUM": "VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "True").lower() in
("true", "1")),
# vLLM will use lightop moe_sum # vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM": "VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
......
...@@ -1895,7 +1895,7 @@ def fused_experts_impl( ...@@ -1895,7 +1895,7 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
from lightop import op as op from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx], output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
......
...@@ -247,6 +247,8 @@ def get_model_architecture( ...@@ -247,6 +247,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -258,12 +260,11 @@ def get_model_architecture( ...@@ -258,12 +260,11 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("USE_FUSED_RMS_QUANT"):
os.environ['USE_FUSED_RMS_QUANT'] = '1'
if not envs.is_set("USE_FUSED_SILU_MUL_QUANT"):
os.environ['USE_FUSED_SILU_MUL_QUANT'] = '1'
# awq相关配置 # awq相关配置
try: try:
if os.getenv('AWQ_MOE_SZ') == None: if os.getenv('AWQ_MOE_SZ') == None:
......
...@@ -214,7 +214,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -214,7 +214,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
...@@ -230,14 +230,14 @@ class DeepseekV2MoE(nn.Module): ...@@ -230,14 +230,14 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment