Commit de87d606 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT when use USE_FUSED_RMS_QUANT and...

add VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT when use USE_FUSED_RMS_QUANT and USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
parent 8118d493
...@@ -633,24 +633,45 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -633,24 +633,45 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) if envs.VLLM_USE_LIGHTOP:
else: kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe) positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim)) self.num_local_heads * self.v_head_dim))
else:
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
weight = self.kv_a_layernorm.weight
cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0], new_residual return self.o_proj(attn_out)[0], new_residual
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
...@@ -661,23 +682,44 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -661,23 +682,44 @@ class DeepseekV2MLAAttention(nn.Module):
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if envs.VLLM_USE_LIGHTOP: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c) if envs.VLLM_USE_LIGHTOP:
else: kv_c_normed = self.kv_a_layernorm.forward_cuda_opt(kv_c)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) else:
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
k_pe = k_pe.unsqueeze(1) q = q.view(-1, self.num_local_heads, self.qk_head_dim)
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe) q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q, attn_out = self.mla_attn(
kv_c_normed, q,
k_pe, kv_c_normed,
output_shape=(hidden_states.shape[0], k_pe,
self.num_local_heads * self.v_head_dim)) output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
else:
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
weight = self.kv_a_layernorm.weight
cos_sin_cache = self.rotary_emb.cos_sin_cache
if cos_sin_cache.device != positions.device or cos_sin_cache.device != q.dtype:
cos_sin_cache = cos_sin_cache.to(positions.device, dtype=q.dtype)
kv_c_normed = torch.empty(kv_c.shape, dtype=kv_c.dtype, device=kv_c.device)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim),
q_ori=q,
key_normed=kv_c_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
packages_ = self.o_proj(attn_out, packages_ = self.o_proj(attn_out,
pa_rms_weight=pa_rms_weight, pa_rms_weight=pa_rms_weight,
pa_residual=pa_residual, pa_residual=pa_residual,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment