Commit 64e307c7 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_OPT_RESHAPE_AND_CACHE (test)

parent 4c92e64a
...@@ -199,6 +199,11 @@ class Attention(nn.Module): ...@@ -199,6 +199,11 @@ class Attention(nn.Module):
# shape does not match the query shape, so we optionally let the model # shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape. # definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None, output_shape: Optional[torch.Size] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
The KV cache is stored inside this class and is accessed via The KV cache is stored inside this class and is accessed via
...@@ -255,8 +260,12 @@ class Attention(nn.Module): ...@@ -255,8 +260,12 @@ class Attention(nn.Module):
attn_metadata, attn_metadata,
output=output) output=output)
else: else:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name, None, q_ori, key_normed, positions, weight, cos_sin_cache)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
...@@ -497,6 +506,11 @@ def unified_attention_with_output( ...@@ -497,6 +506,11 @@ def unified_attention_with_output(
output: torch.Tensor, output: torch.Tensor,
layer_name: str, layer_name: str,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
) -> None: ) -> None:
wait_for_kv_layer_from_connector(layer_name) wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
...@@ -505,6 +519,7 @@ def unified_attention_with_output( ...@@ -505,6 +519,7 @@ def unified_attention_with_output(
attn_metadata = attn_metadata[layer_name] attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine] kv_cache = self.kv_cache[forward_context.virtual_engine]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
self.impl.forward(self, self.impl.forward(self,
query, query,
key, key,
...@@ -513,20 +528,50 @@ def unified_attention_with_output( ...@@ -513,20 +528,50 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale) output_scale=output_scale)
else:
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
q_ori=q_ori,
key_normed=key_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache) tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else: else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    q_ori: Optional[torch.Tensor] = None,
    key_normed: Optional[torch.Tensor] = None,
    positions: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    cos_sin_cache: Optional[torch.Tensor] = None,
) -> None:
    """No-op stand-in for ``unified_attention_with_output``.

    Accepts the same positional and optional arguments as the real
    function and returns immediately without reading or writing any of
    them.

    A single definition is used regardless of
    ``VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT``. The five fused-kernel arguments
    (``q_ori``, ``key_normed``, ``positions``, ``weight``,
    ``cos_sin_cache``) all default to ``None``, so the shorter
    six-argument call form keeps working, and the signature no longer
    depends on the value the flag happened to have when this module was
    imported (the flag can be written into ``os.environ`` later, during
    model-architecture resolution).

    NOTE(review): presumably this is the function registered as the
    fake/meta implementation of the custom op -- confirm at the
    registration site that nothing else relied on the short signature.
    """
    return
......
...@@ -189,6 +189,7 @@ if TYPE_CHECKING: ...@@ -189,6 +189,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSE_SILU_AND_MUL: bool = False VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1238,6 +1239,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in (os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")), ("true", "1")),
# If set, vLLM uses the fused RMSNorm + contiguous + RoPE (for DeepSeek-V3) + concat_and_cache_mla kernel.
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT":
lambda: (os.getenv('VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT', 'False').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -255,6 +255,8 @@ def get_model_architecture( ...@@ -255,6 +255,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
...@@ -267,8 +269,8 @@ def get_model_architecture( ...@@ -267,8 +269,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"): # if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' # os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
...@@ -286,6 +288,8 @@ def get_model_architecture( ...@@ -286,6 +288,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
......
...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope, _yarn_find_correction_range, _yarn_linear_ramp_mask
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -65,6 +65,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -65,6 +65,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
...@@ -607,6 +608,52 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -607,6 +608,52 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
self.max_position_embeddings = rope_scaling["original_max_position_embeddings"]
self.base = rope_theta
self.rotary_dim = qk_rope_head_dim
self.scaling_factor = scaling_factor
self.mscale = mscale
self.extrapolation_factor = 1
self.beta_fast = 32
self.beta_slow = 1
cache = self._compute_cos_sin_cache()
cache = cache.to("cuda")
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Return the blended inverse rotary frequencies for this layer.

    Two candidate inverse-frequency vectors are built from
    ``self.base`` and ``self.rotary_dim``: an unscaled one
    ("extrapolation") and one divided by ``scaling_factor``
    ("interpolation"). They are mixed per frequency with a ramp mask
    obtained from the ``_yarn_*`` helpers.

    Args:
        scaling_factor: Divisor applied to the interpolation branch.

    Returns:
        A float tensor with ``self.rotary_dim // 2`` entries, created on
        the ``"cuda"`` device.
    """
    # base ** (2i / rotary_dim) for every even index 2i < rotary_dim.
    pos_freqs = self.base**(torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                            self.rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                            self.rotary_dim, self.base,
                                            self.max_position_embeddings)
    ramp = _yarn_linear_ramp_mask(low,
                                  high,
                                  self.rotary_dim // 2,
                                  dtype=torch.float)
    # The helper is called without a device, so its result lands on the
    # current default device. Move it next to pos_freqs (a no-op when the
    # devices already match) so the blend below cannot mix devices.
    # Get n-d rotational scaling corrected for extrapolation
    inv_freq_mask = (1 - ramp.to(pos_freqs.device)) * self.extrapolation_factor
    inv_freq = inv_freq_interpolation * (
        1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
    return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the rotary cos/sin lookup table.

    One row is produced per position index in
    ``[0, max_position_embeddings * scaling_factor)``. Each row holds the
    cosines followed by the sines of ``position * inv_freq``, both
    multiplied by ``self.mscale``.

    Returns:
        A float32 tensor on the ``"cuda"`` device whose last dimension is
        twice the length of the inverse-frequency vector.
    """
    inverse_frequencies = self._compute_inv_freq(self.scaling_factor)
    num_positions = self.max_position_embeddings * self.scaling_factor
    position_ids = torch.arange(num_positions,
                                device="cuda",
                                dtype=torch.float32)
    # Outer product: angles[p, k] = position_ids[p] * inverse_frequencies[k]
    angles = torch.einsum("i,j -> ij", position_ids, inverse_frequencies)
    scaled_cos = angles.cos() * self.mscale
    scaled_sin = angles.sin() * self.mscale
    return torch.cat((scaled_cos, scaled_sin), dim=-1)
def forward( def forward(
self, self,
...@@ -697,6 +744,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -697,6 +744,7 @@ class DeepseekV2MLAAttention(nn.Module):
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else: else:
...@@ -715,6 +763,28 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -715,6 +763,28 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
else:
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
weight = torch.ones(kv_c.shape[-1], dtype=q.dtype, device=kv_c.device)
weight = nn.Parameter(weight)
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
if self.cos_sin_cache.device != q.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight.data,
cos_sin_cache=self.cos_sin_cache)
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
......
...@@ -1095,6 +1095,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1095,6 +1095,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M, attn_metadata: M,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
q_ori: Optional[torch.Tensor] = None,
key_normed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
cos_sin_cache: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
...@@ -1129,12 +1134,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1129,12 +1134,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_q = q[:num_decode_tokens] decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:]
else:
q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens]
prefill_q = q_ori[num_decode_tokens:]
# write the latent and rope to kv cache # write the latent and rope to kv cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
ops.concat_and_cache_mla( ops.concat_and_cache_mla(
k_c_normed, k_c_normed,
k_pe.squeeze(1), k_pe.squeeze(1),
...@@ -1143,8 +1155,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1143,8 +1155,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale, scale=layer._k_scale,
) )
else:
from lightop import fused_rms_norm_rope_contiguous
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous(
positions,
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed, # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
1e-6,
)
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[num_decode_tokens:]
output[num_decode_tokens:] = self._forward_prefill( output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata, kv_scale=layer._k_scale) attn_metadata, kv_scale=layer._k_scale)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment