Commit 577eb49f authored by wujl5, committed by zhuwenwen
Browse files

perf: DS-量化模型融合qa和kva的gemm

parent d4e72be3
......@@ -201,6 +201,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1299,6 +1300,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")),
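# Fuse the q_a_proj and kv_a_proj_with_mqa GEMMs into a single GEMM.
# Only quantized DeepSeek models are supported; unquantized versions are not.
# The fused layer is only created when USE_FUSED_RMS_QUANT is also enabled.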
"VLLM_USE_FUSED_QA_KVA_GEMM":
lambda: (os.environ.get("VLLM_USE_FUSED_QA_KVA_GEMM", "False").lower() in
("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")),
......
......@@ -32,6 +32,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
import os
import re
from vllm.model_executor.utils import gemm_bank_conf
from lmslim.quantize.quant_ops import lm_faster_rmsquant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
......@@ -447,6 +448,189 @@ class ReplicatedLinear(LinearBase):
return s
class FusedQuantedReplicatedLinear(LinearBase):
    """Replicated quantized linear layer that fuses two projections into one.

    The layer owns a single weight whose output dimension is
    ``q_lora_rank + kv_lora_rank + qk_rope_head_dim``. Checkpoints still
    store the two halves separately (tensor names containing ``q_a_proj``
    and ``kv_a_proj``), so ``weight_loader`` stages each half as it arrives
    and concatenates them along dim 0 once both halves of the same kind
    (weight or weight scale) are present.

    Only quantized linear methods are supported; loading with
    ``UnquantizedLinearMethod`` raises ``RuntimeError``.

    Args:
        input_size: Input dimension of the linear layer.
        q_lora_rank: Output rows contributed by the q_a projection.
        kv_lora_rank: Together with ``qk_rope_head_dim``, the output rows
            contributed by the kv_a projection.
        qk_rope_head_dim: See ``kv_lora_rank``.
        bias: If true, a bias parameter is allocated (a warning is logged
            because bias is not supported by this implementation).
        skip_bias_add: If true, the bias is returned instead of applied.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        eps: Epsilon forwarded to the fused RMS-norm + quant kernel.
        prefix: Name prefix of the layer.
        return_bias: Must be true; ``forward`` always returns the bias slot.
    """

    def __init__(
        self,
        input_size: int,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_head_dim,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        output_size = q_lora_rank + kv_lora_rank + qk_rope_head_dim
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
        self.eps = eps
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # Staging slots for the two checkpoint halves; each pair is released
        # again as soon as it has been fused into the layer parameter.
        self.q_a_weight = None
        self.kv_a_weight = None
        self.q_a_wscale = None
        self.kv_a_wscale = None
        self.weight_loaded = False

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        # Decoder layer index parsed from the first loaded tensor name;
        # -1 means "not recorded yet".
        self.layer_num = -1

        if bias:
            logger.warning(
                "Quanted DeepSeek-specific implementation. "
                "Bias is not currently supported.")
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        """Stage one checkpoint half and fuse once its counterpart is present.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Original (un-remapped) checkpoint tensor name; it
                decides which half and which kind of tensor this is.

        Raises:
            ValueError: For GGUF parameters, for names that belong to
                neither projection, or for unrecognised tensor kinds.
            RuntimeError: If the layer uses ``UnquantizedLinearMethod`` or
                the tensor is a zero point.
        """
        if getattr(param, "is_gguf_weight", False):
            raise ValueError("Unexpected is_gguf_weight")

        # Promote 0-dim tensors to 1-dim, as the original loader did.
        if loaded_weight.dim() == 0:
            loaded_weight = loaded_weight.reshape(1)

        if isinstance(self.quant_method, UnquantizedLinearMethod):
            raise RuntimeError(
                "Quanted DeepSeek-specific implementation. "
                "not support UnquantizedLinearMethod")

        self._record_layer_num(weight_name)

        if "q_a_proj" in weight_name:
            self._store_qa_weight(loaded_weight, weight_name)
        elif "kv_a_proj" in weight_name:
            self._store_kva_weight(loaded_weight, weight_name)
        else:
            # Silently ignoring the tensor would leave part of the fused
            # parameter uninitialised.
            raise ValueError(f"Unexpected weight: {weight_name}")

        # Only fuse the kind that was just loaded. Checking "either pair is
        # complete" would trigger a scale fusion while one scale is still
        # missing (or vice versa), depending on checkpoint ordering.
        if self._received_two_weight(weight_name):
            self._fused_quantized_weight(weight_name, param)

    def _record_layer_num(self, source: str):
        """Parse the layer index from ``source`` and check it stays constant.

        Raises:
            ValueError: If ``source`` carries no layer index or the index
                differs from the one recorded earlier.
        """
        pattern = r"model\.layers\.(\d+)(?:\.\w+)?\.self_attn"
        match = re.search(pattern, source)
        if match is None:
            raise ValueError(
                f"Cannot parse layer index from weight name: {source}")
        layer_idx = int(match.group(1))
        if self.layer_num == -1:
            self.layer_num = layer_idx
        elif self.layer_num != layer_idx:
            raise ValueError(
                f"self.layer_num: {self.layer_num} != numbers:{layer_idx}\n")

    @staticmethod
    def _weight_kind(source: str) -> str:
        """Classify a tensor name as ``"weight_scale"`` or ``"weight"``.

        Raises:
            RuntimeError: If the name refers to a zero point.
            ValueError: If the name is neither a weight nor a weight scale.
        """
        if "zero" in source:
            raise RuntimeError("Unsupported zero point weight now.")
        # "weight_scale" must be tested first: it also contains "weight".
        if "weight_scale" in source:
            return "weight_scale"
        if "weight" in source:
            return "weight"
        raise ValueError(f"Unexpected weight: {source}")

    def _store_qa_weight(self, loaded_weight: torch.Tensor, source: str):
        """Stage the q_a half (weight or weight scale) named by ``source``."""
        if self._weight_kind(source) == "weight_scale":
            self.q_a_wscale = loaded_weight
        else:
            self.q_a_weight = loaded_weight

    def _store_kva_weight(self, loaded_weight: torch.Tensor, source: str):
        """Stage the kv_a half (weight or weight scale) named by ``source``."""
        if self._weight_kind(source) == "weight_scale":
            self.kv_a_wscale = loaded_weight
        else:
            self.kv_a_weight = loaded_weight

    def _received_two_weight(self, source: Optional[str] = None) -> bool:
        """Tell whether both halves needed for a fusion are staged.

        Args:
            source: Tensor name that was just loaded. When given, only the
                pair of the same kind (weights or weight scales) is
                checked. When omitted, returns true if either pair is
                complete, which is the historical behaviour.
        """
        weights_ready = (self.q_a_weight is not None
                         and self.kv_a_weight is not None)
        scales_ready = (self.q_a_wscale is not None
                        and self.kv_a_wscale is not None)
        if source is None:
            return weights_ready or scales_ready
        if "weight_scale" in source:
            return scales_ready
        return weights_ready

    def _fused_quantized_weight(self, source: str, param: Parameter):
        """Concatenate the staged halves along dim 0 and copy into ``param``.

        The q_a half comes first, followed by the kv_a half. The staged
        tensors of the fused kind are released afterwards so they do not
        stay alive for the lifetime of the layer.

        Raises:
            ValueError: If ``source`` is neither a weight nor a weight
                scale, a half is not 2-D, or the fused scale does not match
                the shape of ``param``.
            RuntimeError: If one of the halves has not been loaded yet.
        """
        is_scale = "weight_scale" in source
        if is_scale:
            halves = (self.q_a_wscale, self.kv_a_wscale)
        elif "weight" in source:
            halves = (self.q_a_weight, self.kv_a_weight)
        else:
            raise ValueError(f"Unexpected weight: {source}")

        if any(half is None for half in halves):
            raise RuntimeError(
                f"Cannot fuse {source}: q_a and kv_a halves are not both "
                "loaded yet.")
        if any(half.dim() != 2 for half in halves):
            raise ValueError(
                f"Expected 2-D tensors when fusing {source}, got shapes "
                f"{[tuple(half.shape) for half in halves]}")

        fused = torch.cat(halves, dim=0)  # TN
        # The original only verified the shape for scales; keep that scope.
        if is_scale and param.data.shape != fused.shape:
            raise ValueError(f"{param.data.shape} == {fused.shape}")
        param.data.copy_(fused)

        # Drop the staged references now that the data lives in ``param``.
        if is_scale:
            self.q_a_wscale = None
            self.kv_a_wscale = None
        else:
            self.q_a_weight = None
            self.kv_a_weight = None

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Parameter]]:
        """Run fused RMS-norm + int8 quantization followed by the fused GEMM.

        Args:
            input_: Input activations.
            rms_weight: RMS-norm weight; required.
            residual: Optional residual forwarded to the fused kernel.
            update_hd: Forwarded to the kernel as ``update_input``.

        Returns:
            ``(output, new_residual, output_bias)``. ``new_residual`` is the
            ``residual`` object that was passed in (presumably updated in
            place by the kernel - TODO confirm). ``output_bias`` is the bias
            when ``skip_bias_add`` is set, otherwise ``None``.

        Raises:
            RuntimeError: If ``USE_FUSED_RMS_QUANT`` is disabled, if
                ``rms_weight`` is missing, or if the layer was built with
                ``return_bias=False``.
        """
        if not envs.USE_FUSED_RMS_QUANT or rms_weight is None:
            raise RuntimeError(
                "FusedQuantedReplicatedLinear requires USE_FUSED_RMS_QUANT "
                "to be enabled and rms_weight to be provided.")
        if not self.return_bias:
            raise RuntimeError("Not return bias. Unexpected Error.")
        assert self.quant_method is not None

        i_q, _scales = lm_faster_rmsquant(input=input_,
                                          rms_weight=rms_weight,
                                          epsilon=self.eps,
                                          quant_dtype=torch.int8,
                                          residual=residual,
                                          update_input=update_hd)
        new_residual = residual
        input_quant_args = [i_q, _scales]

        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        output_bias = self.bias if self.skip_bias_add else None
        return output, new_residual, output_bias

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag for ``repr``."""
        return (f"in_features={self.input_size}"
                f", output_features={self.output_size}"
                f", bias={self.bias is not None}")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import re
import vllm.envs as envs
from collections.abc import Iterable
from typing import Iterable, Optional
......@@ -228,6 +229,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM:
fused_params_mapping = [
("qa_kva_proj", "q_a_proj", 0),
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
......@@ -256,6 +263,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
old_weight_name = name
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
......@@ -264,7 +272,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name]
weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
......
......@@ -52,6 +52,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
FusedQuantedReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -588,6 +589,15 @@ class DeepseekV2MLAAttention(nn.Module):
if self.q_lora_rank is not None:
if envs.USE_FUSED_RMS_QUANT:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
self.qa_kva_proj = FusedQuantedReplicatedLinear(self.hidden_size,
self.q_lora_rank,
self.kv_lora_rank,
self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qa_kva_proj")
else:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
......@@ -624,7 +634,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
if not envs.VLLM_USE_FUSED_QA_KVA_GEMM:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
......@@ -688,6 +698,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO(wjl): the forward here has been split up
def forward(
self,
positions: torch.Tensor,
......@@ -703,13 +715,25 @@ class DeepseekV2MLAAttention(nn.Module):
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None:
qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else:
if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0]
kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
......@@ -1375,6 +1399,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM:
fused_params_mapping = [
("qa_kva_proj", "q_a_proj", 0),
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
......@@ -1407,6 +1437,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
old_weight_name = name
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
......@@ -1418,6 +1449,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
param = params_dict[name]
weight_loader = param.weight_loader
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
break
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment