Commit 80f0794e authored by wangmin6
Browse files

Merge branch 'v0.15.1-fused_fill_rms_cat' into 'v0.15.1-dev'

feat:接入VLLM_USE_FUSED_FILL_RMS_CAT优化

See merge request dcutoolkit/deeplearing/vllm!512
parents 7306fe81 6395b73e
......@@ -41,7 +41,7 @@ from .interfaces import SupportsPP
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -102,6 +102,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
config=self.config,
topk_indices_buffer=topk_indices_buffer,
)
def fuse_fill_rms_x2_concat(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """Forward every argument, unchanged, to lightop's kernel of the same name.

    This wrapper returns nothing; any result is delivered through the tensors
    the kernel writes in place. The op registration in this file lists
    ``hidden_states_fuse`` and ``inputs_embeds`` as the mutated arguments.

    NOTE(review): judging by the name, the kernel presumably fills/masks
    ``inputs_embeds`` based on ``positions``, RMS-normalizes ``inputs_embeds``
    and ``previous_hidden_states`` with their respective weights, and writes
    the concatenation into ``hidden_states_fuse`` -- confirm against lightop.
    """
    # Imported at call time, so lightop is only needed when the op actually
    # runs. The alias avoids shadowing this function's own name locally.
    from lightop import fuse_fill_rms_x2_concat as _lightop_kernel

    _lightop_kernel(
        hidden_states_fuse,
        positions,
        inputs_embeds,
        previous_hidden_states,
        weight_inputs_embeds,
        weight_previous_hidden_states,
        epsilon,
    )
def fuse_fill_rms_x2_concat_fake(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op stand-in with the same signature as ``fuse_fill_rms_x2_concat``.

    It touches none of its arguments and returns ``None``. It is passed as
    ``fake_impl`` when the op is registered.
    """
    return None
# Register the lightop-backed wrapper as a custom op named
# "fuse_fill_rms_x2_concat", with the no-op function as its fake implementation.
# `mutates_args` declares the two tensors the op writes in place; since the op
# returns None, those in-place writes are its only outputs.
# NOTE(review): presumably the fake impl is what tracing/compilation uses in
# place of the real kernel -- confirm against vllm.utils.direct_register_custom_op.
direct_register_custom_op(
op_name="fuse_fill_rms_x2_concat",
op_func=fuse_fill_rms_x2_concat,
mutates_args=["hidden_states_fuse", "inputs_embeds"],
fake_impl=fuse_fill_rms_x2_concat_fake,
)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment