Commit 8d2cac26 authored by zhuwenwen's avatar zhuwenwen
Browse files

[kernel] add lightop's moe_sum(mul+add) fusion operator for deepseek

[FIX] 修复mtp和VLLM_USE_TRITON_CAT不能一起开的bug
parent 5086453d
...@@ -164,10 +164,10 @@ if TYPE_CHECKING: ...@@ -164,10 +164,10 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1095,12 +1095,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1095,12 +1095,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_LIGHT_OP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_TRITON_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_aatn_states,not triton # vLLM will use opt merge_aatn_states,not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
......
...@@ -43,9 +43,17 @@ from vllm.utils import direct_register_custom_op ...@@ -43,9 +43,17 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
if envs.VLLM_USE_LIGHTOP:
from lightop import op
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1'
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
def get_moe_cache(top_k_num,N,K,device,dtype): def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton global moe_cache_singleton
if moe_cache_singleton is None: if moe_cache_singleton is None:
...@@ -1257,13 +1265,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1257,13 +1265,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1289,7 +1299,9 @@ def inplace_fused_experts_fake( ...@@ -1289,7 +1299,9 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> None:
pass pass
...@@ -1325,14 +1337,16 @@ def outplace_fused_experts( ...@@ -1325,14 +1337,16 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
use_int4_w4a16,use_int4_w4a8, per_channel_quant, use_int4_w4a16,use_int4_w4a8, per_channel_quant,
global_num_experts, expert_map, w1_scale, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe) block_shape, use_nn_moe, shared_output, routed_scaling_factor)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1357,7 +1371,9 @@ def outplace_fused_experts_fake( ...@@ -1357,7 +1371,9 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1414,7 +1430,9 @@ def fused_experts( ...@@ -1414,7 +1430,9 @@ def fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False) -> torch.Tensor: use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1472,7 +1490,9 @@ def fused_experts( ...@@ -1472,7 +1490,9 @@ def fused_experts(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def fused_experts_impl( def fused_experts_impl(
...@@ -1500,6 +1520,8 @@ def fused_experts_impl( ...@@ -1500,6 +1520,8 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1544,7 +1566,9 @@ def fused_experts_impl( ...@@ -1544,7 +1566,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
...@@ -1571,7 +1595,9 @@ def fused_experts_impl( ...@@ -1571,7 +1595,9 @@ def fused_experts_impl(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe= False use_nn_moe= False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
# #
...@@ -1744,8 +1770,28 @@ def fused_experts_impl( ...@@ -1744,8 +1770,28 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), if envs.VLLM_USE_LIGHTOP and not dpsk_fp16_quick:
out_hidden_states[begin_chunk_idx:end_chunk_idx]) if shared_output is not None:
op.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx], shared_output[begin_chunk_idx:end_chunk_idx], routed_scaling_factor)
# else:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx])
# if shared_output is not None:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# out_hidden_states[begin_chunk_idx:end_chunk_idx] = out_hidden_states[begin_chunk_idx:end_chunk_idx] * routed_scaling_factor + shared_output[begin_chunk_idx:end_chunk_idx]
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# out_hidden_states[begin_chunk_idx:end_chunk_idx] + shared_output[begin_chunk_idx:end_chunk_idx] * (1. / routed_scaling_factor)
# else:
# if hidden_states.dtype != torch.float16 or dpsk_fp16_quick:
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
...@@ -1779,6 +1825,8 @@ def fused_moe( ...@@ -1779,6 +1825,8 @@ def fused_moe(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1864,7 +1912,9 @@ def fused_moe( ...@@ -1864,7 +1912,9 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
......
...@@ -42,7 +42,9 @@ from vllm.platforms.interface import CpuArchEnum ...@@ -42,7 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from lightop import op
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts
...@@ -222,6 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -222,6 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -373,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -373,6 +376,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -397,6 +401,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -397,6 +401,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
shared_output=shared_output,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate) use_fused_gate=use_fused_gate)
...@@ -418,6 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -418,6 +423,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
shared_output: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
...@@ -460,7 +466,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -460,7 +466,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
use_nn_moe=use_nn_moe use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
) )
def forward_cpu( def forward_cpu(
...@@ -1278,7 +1286,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1278,7 +1286,7 @@ class FusedMoE(torch.nn.Module):
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if use_fused_gate: if use_fused_gate:
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP:
topk_weights, topk_ids = op.moe_fused_gate( topk_weights, topk_ids = op.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias, e_score_correction_bias,
...@@ -1427,13 +1435,14 @@ class FusedMoE(torch.nn.Module): ...@@ -1427,13 +1435,14 @@ class FusedMoE(torch.nn.Module):
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits, shared_output,
self.layer_name) self.layer_name)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
...@@ -1513,7 +1522,8 @@ class FusedMoE(torch.nn.Module): ...@@ -1513,7 +1522,8 @@ class FusedMoE(torch.nn.Module):
return full_final_hidden_states return full_final_hidden_states
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor): router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
...@@ -1547,6 +1557,7 @@ class FusedMoE(torch.nn.Module): ...@@ -1547,6 +1557,7 @@ class FusedMoE(torch.nn.Module):
expert_load_view=self.expert_load_view, expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map, logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
shared_output=shared_output,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate use_fused_gate=self.use_fused_gate
...@@ -1619,17 +1630,17 @@ class FusedMoE(torch.nn.Module): ...@@ -1619,17 +1630,17 @@ class FusedMoE(torch.nn.Module):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits, shared_output)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -9,7 +9,8 @@ from vllm.triton_utils import tl, triton ...@@ -9,7 +9,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up from vllm.utils import cdiv, round_up
import vllm.envs as envs import vllm.envs as envs
from lightop import op if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
@triton.jit @triton.jit
...@@ -232,7 +233,7 @@ def moe_align_block_size( ...@@ -232,7 +233,7 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None) expert_ids, num_tokens_post_pad, None)
else: else:
......
...@@ -230,6 +230,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -230,6 +230,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -272,4 +273,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -272,4 +273,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
...@@ -238,10 +238,16 @@ def get_model_architecture( ...@@ -238,10 +238,16 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0': if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
if (architectures == ['DeepseekV3ForCausalLM'] or architectures == ['DeepSeekMTPModel']):
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
......
...@@ -213,24 +213,30 @@ class DeepseekV2MoE(nn.Module): ...@@ -213,24 +213,30 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if envs.VLLM_USE_LIGHTOP and not self.dpsk_fp16_quick:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits,
shared_output=shared_output)
else: else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = self.experts(hidden_states=hidden_states,
* (1. / self.routed_scaling_factor) router_logits=router_logits)
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
......
...@@ -216,7 +216,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -216,7 +216,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -928,7 +930,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -928,7 +930,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
...@@ -993,7 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -993,7 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
......
...@@ -20,7 +20,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -20,7 +20,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm import envs from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -166,8 +168,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -166,8 +168,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] <= 1024: if q_nope.shape[0] < 1024:
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1) .unsqueeze(1)
else: else:
......
...@@ -5,7 +5,10 @@ from functools import reduce ...@@ -5,7 +5,10 @@ from functools import reduce
import pytest import pytest
import torch import torch
import math import math
from lightop import ds_cat import vllm.envs as envs
if envs.VLLM_USE_LIGHTOP:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim): def test_concat_Acc_prefill(shape_pair, dim):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment