Commit 6395b73e authored by liuchy5's avatar liuchy5
Browse files

feat:接入VLLM_USE_FUSED_FILL_RMS_CAT优化

parent af0e6d8f
...@@ -41,7 +41,7 @@ from .interfaces import SupportsPP ...@@ -41,7 +41,7 @@ from .interfaces import SupportsPP
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -102,6 +102,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -102,6 +102,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
config=self.config, config=self.config,
topk_indices_buffer=topk_indices_buffer, topk_indices_buffer=topk_indices_buffer,
) )
def fuse_fill_rms_x2_concat(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """Forward all arguments to the ``lightop`` kernel of the same name.

    This wrapper adds no logic of its own: every argument is passed through
    positionally, in declaration order, and nothing is returned.

    NOTE(review): judging by the name, the kernel presumably fuses a fill,
    two RMS norms and a concat, writing its result into
    ``hidden_states_fuse`` -- confirm against the lightop documentation.
    """
    # Imported at call time so that ``lightop`` is only needed when the op is
    # actually invoked. Aliased so the kernel does not shadow this wrapper's
    # own name inside the function body.
    from lightop import fuse_fill_rms_x2_concat as _lightop_kernel

    _lightop_kernel(
        hidden_states_fuse,
        positions,
        inputs_embeds,
        previous_hidden_states,
        weight_inputs_embeds,
        weight_previous_hidden_states,
        epsilon,
    )
def fuse_fill_rms_x2_concat_fake(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op stand-in for ``fuse_fill_rms_x2_concat``.

    Accepts the same parameters as the real implementation, touches none of
    them and returns ``None``.
    """
    return None
# Register the fused kernel wrapper as a custom op named
# "fuse_fill_rms_x2_concat", pairing the real implementation with its no-op
# fake implementation.
# NOTE(review): presumably this makes the op callable through torch.ops and
# traceable by torch.compile -- confirm against direct_register_custom_op.
direct_register_custom_op(
op_name="fuse_fill_rms_x2_concat",
op_func=fuse_fill_rms_x2_concat,
# Arguments declared as modified in place by the op.
mutates_args=["hidden_states_fuse", "inputs_embeds"],
fake_impl=fuse_fill_rms_x2_concat_fake,
)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment