Commit e80dcabe authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_FUSED_FILL_RMS_CAT to enable a fused kernel for the DeepSeek MTP step (position-0 fill + 2x RMSNorm + concat)

Update the VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT implementation
parent 4f9947e6
......@@ -553,18 +553,7 @@ def unified_attention_with_output(
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return
else:
def unified_attention_with_output_fake(
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
......
......@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1275,7 +1276,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
("true", "1")),
# If set, vLLM uses the fused lightop kernel for the DeepSeek MTP input step (position-0 fill + two RMSNorms + concat)
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -253,8 +253,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
......@@ -298,8 +296,6 @@ def get_model_architecture(
os.environ['VLLM_USE_OPT_CAT'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
# os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
if not envs.is_set("VLLM_SCHED_ENABLE_MINIMAL_INJECTION"):
......
......@@ -28,6 +28,9 @@ from .interfaces import SupportsPP
from .utils import maybe_prefix
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
import vllm.envs as envs
from lightop import fuse_fill_rms_x2_concat
class SharedHead(nn.Module):
......@@ -84,10 +87,14 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
    # Fused path: one lightop kernel writes the projection input into a
    # preallocated buffer instead of running the separate ops in the
    # fallback branch below.
    # NOTE(review): presumably the kernel performs the same position-0
    # zero fill, enorm/hnorm and concat as the fallback - confirm against
    # the lightop implementation.
    #
    # Size the buffer from the two tensors being concatenated rather than
    # from `hidden_states`: that name is only assigned after the
    # projection in the visible code, and the original line also
    # misspelled `.shape` as `.shaope`, which raised AttributeError
    # whenever this flag was enabled.
    # Assumes 2-D (num_tokens, hidden) inputs, as the original allocation
    # did - TODO confirm.
    fused_width = inputs_embeds.shape[-1] + previous_hidden_states.shape[-1]
    hidden_states_fuse = torch.empty(
        previous_hidden_states.shape[0],
        fused_width,
        device=previous_hidden_states.device,
        dtype=previous_hidden_states.dtype,
    )
    # NOTE(review): only enorm's epsilon is passed; this assumes hnorm
    # uses the same variance_epsilon - confirm.
    fuse_fill_rms_x2_concat(
        hidden_states_fuse,
        positions,
        inputs_embeds,
        previous_hidden_states,
        self.enorm.weight,
        self.hnorm.weight,
        self.enorm.variance_epsilon,
    )
    hidden_states = self.eh_proj(hidden_states_fuse)
else:
    # Unfused fallback: zero the embeddings at position 0 (not needed by
    # MTP), normalise both inputs, then project their concatenation.
    inputs_embeds[positions == 0] = 0
    inputs_embeds = self.enorm(inputs_embeds)
    previous_hidden_states = self.hnorm(previous_hidden_states)
    hidden_states = self.eh_proj(
        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
......
......@@ -217,6 +217,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1163,7 +1164,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
from lightop import fused_rms_norm_rope_contiguous
fused_rms_norm_rope_contiguous(
positions[:num_actual_toks, ...],
q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment