Commit 577eb49f authored by wujl5, committed by zhuwenwen
Browse files

perf: DS-量化模型融合qa和kva的gemm

parent d4e72be3
...@@ -201,6 +201,7 @@ if TYPE_CHECKING: ...@@ -201,6 +201,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1299,6 +1300,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1299,6 +1300,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM": "VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
("true", "1")), ("true", "1")),
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM":
lambda: (os.environ.get("VLLM_USE_FUSED_QA_KVA_GEMM", "False").lower() in
("true", "1")),
"VLLM_ZERO_OVERHEAD_ENHANCE": "VLLM_ZERO_OVERHEAD_ENHANCE":
lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in lambda: (os.getenv('VLLM_ZERO_OVERHEAD_ENHANCE', '0').lower() in
("true", "1")), ("true", "1")),
......
...@@ -32,10 +32,11 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -32,10 +32,11 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
import os import os
import re
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
from lmslim.quantize.quant_ops import lm_faster_rmsquant from lmslim.quantize.quant_ops import lm_faster_rmsquant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
logger = init_logger(__name__) logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
...@@ -447,6 +448,189 @@ class ReplicatedLinear(LinearBase): ...@@ -447,6 +448,189 @@ class ReplicatedLinear(LinearBase):
return s return s
class FusedQuantedReplicatedLinear(LinearBase):
    """Replicated linear layer fusing DeepSeek's q_a and kv_a projections.

    The layer owns a single quantized weight whose output dimension is
    ``q_lora_rank + kv_lora_rank + qk_rope_head_dim``.  Checkpoints still
    ship the two projections separately, so :meth:`weight_loader` stages the
    ``q_a_proj`` and ``kv_a_proj`` tensors as they arrive and, once both
    halves of the same kind (weight or weight scale) are present,
    concatenates them along dim 0 (q_a first) into the fused parameter.

    Only quantized linear methods are accepted; zero-point tensors are
    rejected.
    """

    # Extracts the decoder layer index from a checkpoint tensor name.
    _LAYER_NUM_PATTERN = re.compile(
        r"model\.layers\.(\d+)(?:\.\w+)?\.self_attn")

    def __init__(
        self,
        input_size: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Create the fused parameter through the configured quant method.

        Args:
            input_size: Input feature dimension of the GEMM.
            q_lora_rank: Number of output rows taken from ``q_a_proj``.
            kv_lora_rank: Together with ``qk_rope_head_dim``, the number of
                output rows taken from ``kv_a_proj``.
            qk_rope_head_dim: See ``kv_lora_rank``.
            bias: Whether to allocate a bias parameter.  A warning is logged
                when True because this implementation does not support it.
            skip_bias_add: If True, the bias is returned from ``forward``
                instead of being passed to the quant method.
            params_dtype: Forwarded to ``LinearBase``.
            quant_config: Forwarded to ``LinearBase``.
            eps: Epsilon handed to ``lm_faster_rmsquant`` in ``forward``.
            prefix: Forwarded to ``LinearBase``.
            return_bias: Forwarded to ``LinearBase``; ``forward`` requires
                it to be True.
        """
        output_size = q_lora_rank + kv_lora_rank + qk_rope_head_dim
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

        self.eps = eps
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # Staging slots for the two checkpoint halves; each pair is cleared
        # again as soon as it has been fused into the layer parameter.
        self.q_a_weight = None
        self.kv_a_weight = None
        self.q_a_wscale = None
        self.kv_a_wscale = None
        self.weight_loaded = False
        # Layer index seen in checkpoint names; -1 until the first tensor.
        self.layer_num = -1

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            logger.warning(
                "Quanted DeepSeek-specific implementation. "
                "Bias is not currently supported.")
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        """Stage one checkpoint tensor and fuse once its counterpart exists.

        Args:
            param: Fused destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Original checkpoint name (before it was remapped to
                the fused parameter); used to tell q_a from kv_a and weight
                from weight scale.

        Raises:
            ValueError: For GGUF parameters, names that belong to neither
                projection, or names that are neither weight nor scale.
            RuntimeError: When the layer is unquantized or a zero-point
                tensor is supplied.
        """
        if getattr(param, "is_gguf_weight", False):
            raise ValueError("Unexpected is_gguf_weight")

        # Scalars are promoted to 1-element tensors.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if isinstance(self.quant_method, UnquantizedLinearMethod):
            raise RuntimeError(
                "Quanted DeepSeek-specific implementation. "
                "UnquantizedLinearMethod is not supported.")

        self._record_layer_num(weight_name)

        if "q_a_proj" in weight_name:
            self._store_qa_weight(loaded_weight, weight_name)
        elif "kv_a_proj" in weight_name:
            self._store_kva_weight(loaded_weight, weight_name)
        else:
            raise ValueError(f"Unexpected weight: {weight_name}")

        # Readiness is checked for the kind that just arrived.  Checking
        # "either pair complete" would try to fuse scales as soon as both
        # weights are present (or vice versa), depending on checkpoint
        # order.
        if self._received_two_weight(weight_name):
            self._fused_quantized_weight(weight_name, param)

    @staticmethod
    def _weight_kind(source: str) -> str:
        """Classify a checkpoint name as ``"weight_scale"`` or ``"weight"``.

        ``"weight_scale"`` is tested first because it contains ``"weight"``.

        Raises:
            RuntimeError: If the name refers to a zero-point tensor.
            ValueError: If the name is neither a weight nor a weight scale.
        """
        if "zero" in source:
            raise RuntimeError("Unsupported zero point weight now.")
        if "weight_scale" in source:
            return "weight_scale"
        if "weight" in source:
            return "weight"
        raise ValueError(f"Unexpected weight: {source}")

    def _record_layer_num(self, source: str):
        """Remember the layer index in ``source`` and check it never changes.

        Raises:
            ValueError: If ``source`` carries no layer index, or the index
                differs from the one recorded by an earlier tensor.
        """
        match = self._LAYER_NUM_PATTERN.search(source)
        if match is None:
            raise ValueError(
                f"Cannot extract layer number from weight name: {source}")
        number = int(match.group(1))
        if self.layer_num == -1:
            self.layer_num = number
        elif self.layer_num != number:
            raise ValueError(
                f"self.layer_num: {self.layer_num} != numbers:{number}\n")

    def _store_qa_weight(self, loaded_weight: torch.Tensor, source: str):
        """Stage the q_a half (weight or weight scale) named by ``source``."""
        if self._weight_kind(source) == "weight_scale":
            self.q_a_wscale = loaded_weight
        else:
            self.q_a_weight = loaded_weight

    def _store_kva_weight(self, loaded_weight: torch.Tensor, source: str):
        """Stage the kv_a half (weight or weight scale) named by ``source``."""
        if self._weight_kind(source) == "weight_scale":
            self.kv_a_wscale = loaded_weight
        else:
            self.kv_a_weight = loaded_weight

    def _received_two_weight(self, source: Optional[str] = None):
        """Tell whether a q_a/kv_a pair is fully staged.

        Args:
            source: Checkpoint name selecting which pair to test (weights or
                weight scales).  When omitted, True is returned if either
                pair is complete.
        """
        weights_ready = (self.q_a_weight is not None
                         and self.kv_a_weight is not None)
        scales_ready = (self.q_a_wscale is not None
                        and self.kv_a_wscale is not None)
        if source is None:
            return weights_ready or scales_ready
        if self._weight_kind(source) == "weight_scale":
            return scales_ready
        return weights_ready

    def _fused_quantized_weight(self, source: str, param: Parameter):
        """Concatenate the staged pair selected by ``source`` into ``param``.

        The q_a tensor comes first along dim 0, followed by the kv_a tensor.
        The staged tensors are dropped afterwards so the checkpoint copies
        are not kept alive by this layer.

        Raises:
            RuntimeError: If one half of the pair has not been staged yet.
            ValueError: If a staged tensor is not 2-D, or the fused scale
                does not match the parameter shape.
        """
        is_scale = self._weight_kind(source) == "weight_scale"
        if is_scale:
            q_half, kv_half = self.q_a_wscale, self.kv_a_wscale
        else:
            q_half, kv_half = self.q_a_weight, self.kv_a_weight

        if q_half is None or kv_half is None:
            raise RuntimeError(
                f"Cannot fuse {source}: q_a and kv_a tensors are not both "
                "loaded yet.")
        if len(q_half.shape) != 2 or len(kv_half.shape) != 2:
            raise ValueError(
                f"Expected 2-D tensors to fuse for {source}, got "
                f"{tuple(q_half.shape)} and {tuple(kv_half.shape)}")

        fused = torch.cat([q_half, kv_half], dim=0)  # TN
        if is_scale and param.data.shape != fused.shape:
            raise ValueError(f"{param.data.shape} != {fused.shape}")
        param.data.copy_(fused)

        # Release the staged halves that were just consumed.
        if is_scale:
            self.q_a_wscale = None
            self.kv_a_wscale = None
        else:
            self.q_a_weight = None
            self.kv_a_weight = None

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Parameter]]:
        """Quantize the input via ``lm_faster_rmsquant`` and run the GEMM.

        Args:
            input_: Activations fed to ``lm_faster_rmsquant``.
            rms_weight: RMS weight handed to ``lm_faster_rmsquant``;
                required.
            residual: Optional residual handed to ``lm_faster_rmsquant`` and
                returned as the second output.
            update_hd: Passed to ``lm_faster_rmsquant`` as ``update_input``.

        Returns:
            ``(output, residual, output_bias)``.  ``output_bias`` is the
            bias only when ``skip_bias_add`` is set, otherwise None.

        Raises:
            RuntimeError: If ``envs.USE_FUSED_RMS_QUANT`` is off,
                ``rms_weight`` is missing, or ``return_bias`` is False.
        """
        if not (envs.USE_FUSED_RMS_QUANT and rms_weight is not None):
            raise RuntimeError(
                "FusedQuantedReplicatedLinear.forward requires "
                "USE_FUSED_RMS_QUANT to be enabled and rms_weight to be "
                "provided.")
        # Callers unpack three values, so fail before touching any tensor.
        if not self.return_bias:
            raise RuntimeError("Not return bias. Unexpected Error.")

        i_q, _scales = lm_faster_rmsquant(input=input_,
                                          rms_weight=rms_weight,
                                          epsilon=self.eps,
                                          quant_dtype=torch.int8,
                                          residual=residual,
                                          update_input=update_hd)
        # NOTE(review): presumably lm_faster_rmsquant updates ``residual``
        # in place; the same tensor object is handed back to the caller.
        new_residual = residual
        input_quant_args = [i_q, _scales]

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias,
                                         input_quant_args)
        output_bias = self.bias if self.skip_bias_add else None
        return output, new_residual, output_bias

    def extra_repr(self) -> str:
        """Return the feature sizes and bias flag for the module repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
import re import re
import vllm.envs as envs
from collections.abc import Iterable from collections.abc import Iterable
from typing import Iterable, Optional from typing import Iterable, Optional
...@@ -228,6 +229,12 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -228,6 +229,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM:
fused_params_mapping = [
("qa_kva_proj", "q_a_proj", 0),
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
...@@ -256,6 +263,7 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -256,6 +263,7 @@ class DeepSeekMTP(nn.Module, SupportsPP):
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
old_weight_name = name
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -264,7 +272,12 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -264,7 +272,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for mapping in expert_params_mapping: for mapping in expert_params_mapping:
......
...@@ -52,6 +52,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -52,6 +52,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
FusedQuantedReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -588,12 +589,21 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -588,12 +589,21 @@ class DeepseekV2MLAAttention(nn.Module):
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
self.q_a_proj = ReplicatedLinear(self.hidden_size, if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
self.q_lora_rank, self.qa_kva_proj = FusedQuantedReplicatedLinear(self.hidden_size,
bias=False, self.q_lora_rank,
quant_config=quant_config, self.kv_lora_rank,
eps=config.rms_norm_eps, self.qk_rope_head_dim,
prefix=f"{prefix}.q_a_proj") bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qa_kva_proj")
else:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank, self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads * self.num_heads *
self.qk_head_dim, self.qk_head_dim,
...@@ -624,13 +634,13 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -624,13 +634,13 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj") prefix=f"{prefix}.q_proj")
if not envs.VLLM_USE_FUSED_QA_KVA_GEMM:
self.kv_a_proj_with_mqa = ReplicatedLinear( self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size, self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank + self.qk_rope_head_dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa") prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -687,7 +697,9 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -687,7 +697,9 @@ class DeepseekV2MLAAttention(nn.Module):
self.prefix = prefix self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2]) self.debug_layer_idx = int(self.prefix.split(".")[-2])
# TODO wjl: 这里的forward拆了
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -703,13 +715,25 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -703,13 +715,25 @@ class DeepseekV2MLAAttention(nn.Module):
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None: if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) if self.q_lora_rank is not None:
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False) qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else: else:
q = self.q_proj(hidden_states)[0] if self.q_lora_rank is not None:
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0]
kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
...@@ -1375,6 +1399,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1375,6 +1399,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM:
fused_params_mapping = [
("qa_kva_proj", "q_a_proj", 0),
("qa_kva_proj", "kv_a_proj_with_mqa", 1)
]
stacked_params_mapping += fused_params_mapping
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
...@@ -1407,6 +1437,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1407,6 +1437,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# for mlp.experts[0].gate_gate_up_proj, which breaks load. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
old_weight_name = name
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -1418,7 +1449,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1418,7 +1449,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
if envs.USE_FUSED_RMS_QUANT and envs.VLLM_USE_FUSED_QA_KVA_GEMM and (("q_a_proj" in old_weight_name) or ("kv_a_proj_with_mqa" in old_weight_name)):
weight_loader(param, loaded_weight, old_weight_name)
else:
weight_loader(param, loaded_weight, shard_id)
break break
else: else:
is_expert_weight = False is_expert_weight = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment