Commit 8d2cac26 authored by zhuwenwen
Browse files

[kernel] add lightop's moe_sum(mul+add) fusion operator for deepseek

[FIX] 修复mtp和VLLM_USE_TRITON_CAT不能一起开的bug
parent 5086453d
......@@ -164,10 +164,10 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
def get_default_cache_root():
......@@ -1095,12 +1095,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use lightop kernels for moe (moe_fused_gate, moe_align_block_size, moe_sum)
"VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in
"VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")),
# vLLM will use the optimized concat helpers (prefill/decode) in the MLA attention backends
"VLLM_USE_TRITON_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "True").lower() in
("true", "1")),
# vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
......
......@@ -43,9 +43,17 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
if envs.VLLM_USE_LIGHTOP:
from lightop import op
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
if moe_cache_singleton is None:
......@@ -1257,13 +1265,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def inplace_fused_experts_fake(
......@@ -1289,7 +1299,9 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
pass
......@@ -1325,14 +1337,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe)
block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def outplace_fused_experts_fake(
......@@ -1357,7 +1371,9 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1414,7 +1430,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
......@@ -1472,7 +1490,9 @@ def fused_experts(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def fused_experts_impl(
......@@ -1500,6 +1520,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
......@@ -1544,7 +1566,9 @@ def fused_experts_impl(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=False
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
)
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
......@@ -1571,7 +1595,9 @@ def fused_experts_impl(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe= False
use_nn_moe= False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
)
#
......@@ -1744,8 +1770,28 @@ def fused_experts_impl(
block_shape=block_shape,
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
if shared_output is not None:
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
......@@ -1779,6 +1825,8 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1864,7 +1912,9 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
use_nn_moe=use_nn_moe)
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
......
......@@ -42,7 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm import _custom_ops as ops
from lightop import op
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
......@@ -222,6 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
......@@ -373,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
......@@ -397,6 +401,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
shared_output=shared_output,
use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
......@@ -418,6 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
......@@ -460,7 +466,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=use_nn_moe
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
)
def forward_cpu(
......@@ -1278,7 +1286,7 @@ class FusedMoE(torch.nn.Module):
assert topk_group is not None
assert num_expert_group is not None
if use_fused_gate:
if envs.VLLM_USE_LIGHT_OP:
if envs.VLLM_USE_LIGHTOP:
topk_weights, topk_ids = op.moe_fused_gate(
router_logits,
e_score_correction_bias,
......@@ -1427,13 +1435,14 @@ class FusedMoE(torch.nn.Module):
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
return torch.ops.vllm.moe_forward(hidden_states, router_logits, shared_output,
self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
......@@ -1513,7 +1522,8 @@ class FusedMoE(torch.nn.Module):
return full_final_hidden_states
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
......@@ -1547,6 +1557,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate
......@@ -1619,17 +1630,17 @@ class FusedMoE(torch.nn.Module):
return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
......
......@@ -9,7 +9,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up
import vllm.envs as envs
from lightop import op
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
@triton.jit
......@@ -232,7 +233,7 @@ def moe_align_block_size(
dtype=torch.int32,
device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP:
if envs.VLLM_USE_LIGHTOP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None)
else:
......
......@@ -230,6 +230,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -272,4 +273,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -238,10 +238,16 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0'
else:
os.environ['LM_NN'] = '1'
if (architectures == ['DeepseekV3ForCausalLM'] or architectures == ['DeepSeekMTPModel']):
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
......
......@@ -213,24 +213,30 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
if envs.VLLM_USE_LIGHTOP and not self.dpsk_fp16_quick:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
router_logits=router_logits,
shared_output=shared_output)
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO:
......
......@@ -216,7 +216,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -928,7 +930,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
......@@ -993,7 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
......
......@@ -20,7 +20,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__)
......@@ -166,8 +168,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT:
if q_nope.shape[0] <= 1024:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
......
......@@ -5,7 +5,10 @@ from functools import reduce
import pytest
import torch
import math
from lightop import ds_cat
import vllm.envs as envs
if envs.VLLM_USE_LIGHTOP:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment