Unverified commit 4a63c181 authored by Ke Bao, committed by GitHub

Fix AWQ with MLA enabled (#2364)
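When MLA is enabled, each attention layer's kv_b_proj weight is split into w_kc and w_vc at the end of weight loading. AWQ-quantized checkpoints do not expose a plain weight tensor on that projection, only the packed qweight/scales/qzeros, so the split failed. This change detects the AWQ case and dequantizes the projection with vllm's awq_dequantize op first; the dequantized matrix comes back in [in_features, out_features] layout, so it is transposed with .T to the usual [out_features, in_features] layout before the unflatten/split. The trailing zero arguments appear to be the kernel's split-k and thread-block parameters left at their defaults.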

parent 2b0fc594
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -894,7 +895,19 @@ class DeepseekV2ForCausalLM(nn.Module):
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
...
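For reference, a minimal standalone sketch of the unflatten/split shape logic used above (the dimension values are illustrative assumptions, not taken from this commit):

import torch

# Illustrative MLA sizes (assumed for the sketch, not from the diff).
num_heads = 4
kv_lora_rank = 512
qk_nope_head_dim = 128
v_head_dim = 128

# kv_b_proj maps kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim),
# so its (dequantized) weight is [out_features, in_features].
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Group the rows per head, then split each head's block into the K (nope) and V parts.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # [4, 128, 512] -> (num_heads, qk_nope_head_dim, kv_lora_rank)
print(w_vc.shape)  # [4, 128, 512] -> (num_heads, v_head_dim, kv_lora_rank)

The AWQ branch dequantizes before this step, presumably because the packed int4 qweight cannot be reshaped per head without first being unpacked to a full floating-point matrix.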