Commit de87d606 authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT when using USE_FUSED_RMS_QUANT and...

Add VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT support when using USE_FUSED_RMS_QUANT and USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
parent 8118d493
......@@ -633,6 +633,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
......@@ -651,6 +652,26 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
else:
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
weight = self.kv_a_layernorm.weight
cos_sin_cache = self.rotary_emb.cos_sin_cache
# Convert the rotary cache only when its device or dtype really differs.
# The dtype must be compared with q.dtype: comparing `.device` against a
# dtype is always unequal, which made this guard unconditionally true.
if (cos_sin_cache.device != positions.device
        or cos_sin_cache.dtype != q.dtype):
    cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0], new_residual
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None:
......@@ -661,6 +682,7 @@ class DeepseekV2MLAAttention(nn.Module):
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if envs.VLLM_USE_LIGHTOP:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
else:
......@@ -678,6 +700,26 @@ class DeepseekV2MLAAttention(nn.Module):
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
else:
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
weight = self.kv_a_layernorm.weight
cos_sin_cache = self.rotary_emb.cos_sin_cache
# Convert the rotary cache only when its device or dtype really differs.
# The dtype must be compared with q.dtype: comparing `.device` against a
# dtype is always unequal, which made this guard unconditionally true.
if (cos_sin_cache.device != positions.device
        or cos_sin_cache.dtype != q.dtype):
    cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment