Commit bac269d7 authored by zhuwenwen
Browse files

update fuse_fill_rms_x2_concat

parent bdae1255
...@@ -29,8 +29,7 @@ from .utils import maybe_prefix ...@@ -29,8 +29,7 @@ from .utils import maybe_prefix
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs import vllm.envs as envs
from lightop import fuse_fill_rms_x2_concat from vllm.utils import direct_register_custom_op
class SharedHead(nn.Module): class SharedHead(nn.Module):
...@@ -75,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -75,6 +74,24 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
cache_config, quant_config) cache_config, quant_config)
def fuse_fill_rms_x2_concat(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to the ``lightop`` kernel.

    This is a thin Python wrapper around ``lightop.fuse_fill_rms_x2_concat``;
    it performs no computation of its own and always returns ``None`` (any
    value returned by the kernel is discarded).

    Args:
        hidden_states_fuse: Tensor handed to the kernel as its first
            argument. Presumably the pre-allocated output buffer that the
            kernel fills in place -- TODO confirm against lightop.
        positions: Tensor passed through to the kernel.
        inputs_embeds: Tensor passed through to the kernel. NOTE(review):
            may be modified in place by the kernel -- confirm.
        previous_hidden_states: Tensor passed through to the kernel.
        weight_inputs_embeds: Weight tensor passed through to the kernel;
            by its name, paired with ``inputs_embeds``.
        weight_previous_hidden_states: Weight tensor passed through to the
            kernel; by its name, paired with ``previous_hidden_states``.
        epsilon: Float passed through to the kernel.

    Raises:
        ImportError: If the ``lightop`` package cannot be imported at call
            time.
    """
    # Imported at call time so that `lightop` is only required when this
    # function actually runs. The alias keeps the kernel from rebinding this
    # wrapper's own name inside the function body, which would otherwise read
    # like a recursive call.
    from lightop import fuse_fill_rms_x2_concat as _lightop_kernel

    _lightop_kernel(
        hidden_states_fuse,
        positions,
        inputs_embeds,
        previous_hidden_states,
        weight_inputs_embeds,
        weight_previous_hidden_states,
        epsilon,
    )
def fuse_fill_rms_x2_concat_fake(
    hidden_states_fuse: torch.Tensor,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    weight_inputs_embeds: torch.Tensor,
    weight_previous_hidden_states: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op stand-in with the same signature as ``fuse_fill_rms_x2_concat``.

    Accepts the same seven arguments, touches none of them, and returns
    ``None``. It never imports or calls ``lightop``.

    Args:
        hidden_states_fuse: Ignored.
        positions: Ignored.
        inputs_embeds: Ignored.
        previous_hidden_states: Ignored.
        weight_inputs_embeds: Ignored.
        weight_previous_hidden_states: Ignored.
        epsilon: Ignored.
    """
    # Intentionally empty: there is no return value to describe and nothing
    # is computed here.
    return None
# Register the lightop wrapper as a custom op. The model's forward pass
# invokes it as `torch.ops.vllm.fuse_fill_rms_x2_concat(...)`.
direct_register_custom_op(
    op_name="fuse_fill_rms_x2_concat",
    # Real implementation and its no-op counterpart, kept side by side.
    # NOTE(review): the fake impl is presumably what runs when the op is
    # traced without real data -- confirm against direct_register_custom_op.
    op_func=fuse_fill_rms_x2_concat,
    fake_impl=fuse_fill_rms_x2_concat_fake,
    # Arguments declared as modified in place by the op.
    # NOTE(review): confirm the kernel really writes to `inputs_embeds` as
    # well as to `hidden_states_fuse`.
    mutates_args=["hidden_states_fuse", "inputs_embeds"],
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -88,8 +105,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ...@@ -88,8 +105,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
assert inputs_embeds is not None assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP # masking inputs at position 0, as not needed by MTP
if envs.VLLM_USE_FUSED_FILL_RMS_CAT: if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
hidden_states_fuse = torch.empty(hidden_states.shape[0], hidden_states.shaope[1]*2, device=hidden_states.device, dtype=hidden_states.dtype) hidden_states_fuse = torch.empty(inputs_embeds.shape[0], inputs_embeds.shape[1]*2, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, self.enorm.weight, self.hnorm.weight, self.enorm.variance_epsilon) torch.ops.vllm.fuse_fill_rms_x2_concat(hidden_states_fuse, positions, inputs_embeds, previous_hidden_states, self.enorm.weight, self.hnorm.weight, self.enorm.variance_epsilon)
hidden_states = self.eh_proj(hidden_states_fuse) hidden_states = self.eh_proj(hidden_states_fuse)
else: else:
inputs_embeds[positions == 0] = 0 inputs_embeds[positions == 0] = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment