Unverified commit 4a63c181, authored by Ke Bao, committed by GitHub

Fix AWQ with enable MLA (#2364)

parent 2b0fc594
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -894,7 +895,19 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
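Context for the change: with MLA enabled, the model absorbs kv_b_proj into per-head w_kc/w_vc matrices, which requires a dense weight tensor. An AWQ-quantized checkpoint stores kv_b_proj as a packed int4 qweight (with scales and qzeros) rather than a dense weight, so the fix dequantizes it with ops.awq_dequantize and transposes the result back to the usual (out_features, in_features) layout before the split; the three trailing zeros fill the kernel's split_k_iters/thx/thy parameters in the vLLM version targeted here. Below is a minimal sketch of just the unflatten/split step on a dense weight. The dimensions are illustrative placeholders, not values from the model config.

    import torch

    num_heads = 16              # hypothetical; the real value comes from the config
    qk_nope_head_dim = 128      # hypothetical no-RoPE query/key head dim
    v_head_dim = 128            # hypothetical value head dim
    kv_lora_rank = 512          # hypothetical low-rank KV dim

    # Dense kv_b_proj weight laid out as (num_heads * (qk_nope + v), kv_lora_rank).
    w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # Regroup rows per head, then split each head's rows into its K and V parts,
    # mirroring the w.unflatten(...).split(...) call in the diff above.
    w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
        [qk_nope_head_dim, v_head_dim], dim=1
    )
    assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
    assert w_vc.shape == (num_heads, v_head_dim, kv_lora_rank)

The hasattr(self_attn.kv_b_proj, "qweight") check is what keeps the original unquantized path working: only AWQ layers carry a qweight attribute, so dense checkpoints still read self_attn.kv_b_proj.weight directly.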