"tests/vscode:/vscode.git/clone" did not exist on "48661d275fb44b969112a7bd8586dfd9f498e2e3"
Commit 14688ccd authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

parents 55310f4f fd559b9f
...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope, _yarn_find_correction_range, _yarn_linear_ramp_mask from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -608,52 +608,6 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -608,52 +608,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
self.max_position_embeddings = rope_scaling["original_max_position_embeddings"]
self.base = rope_theta
self.rotary_dim = qk_rope_head_dim
self.scaling_factor = scaling_factor
self.mscale = mscale
self.extrapolation_factor = 1
self.beta_fast = 32
self.beta_slow = 1
cache = self._compute_cos_sin_cache()
cache = cache.to("cuda")
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(
torch.arange(0,
self.rotary_dim,
2,
dtype=torch.float,
device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward( def forward(
self, self,
...@@ -767,12 +721,10 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -767,12 +721,10 @@ class DeepseekV2MLAAttention(nn.Module):
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
weight = torch.ones(kv_c.shape[-1], dtype=q.dtype, device=kv_c.device) weight = self.kv_a_layernorm.weight
weight = nn.Parameter(weight) cos_sin_cache = self.rotary_emb.cos_sin_cache
if self.cos_sin_cache.device != positions.device: if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device) cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
if self.cos_sin_cache.device != q.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device) kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
...@@ -783,8 +735,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -783,8 +735,8 @@ class DeepseekV2MLAAttention(nn.Module):
q_ori=q, q_ori=q,
key_normed=kv_c_normed, key_normed=kv_c_normed,
positions=positions, positions=positions,
weight=weight.data, weight=weight,
cos_sin_cache=self.cos_sin_cache) cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment