Commit 16732666 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-refactor-code-rebase-fb39e-squash' into 'v0.9.2-dev'

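DS量化模型重构atten和moe调用rmsquant融合逻辑。
(English: Refactor how the attention and MoE paths of the quantized DeepSeek model invoke the fused RMSNorm + quantization (rmsquant) logic.)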

See merge request dcutoolkit/deeplearing/vllm!345
parents fb39e61b 89b62a25
......@@ -34,7 +34,7 @@ class SharedFusedMoE(FusedMoE):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
hidden_states_copy: Optional[torch.Tensor] = None
hidden_states_copy: Optional[torch.Tensor] = None, **_
) -> tuple[torch.Tensor, torch.Tensor]|torch.Tensor:
if not self.use_overlapped:
shared_out = self._shared_experts(hidden_states)
......
......@@ -897,7 +897,7 @@ class EPSharedExperts(nn.Module):
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
def forward(self, x, **_):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
......
......@@ -331,7 +331,6 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -343,7 +342,6 @@ class ReplicatedLinear(LinearBase):
quant_config,
prefix=prefix,
return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method.
assert self.quant_method is not None
......@@ -393,44 +391,18 @@ class ReplicatedLinear(LinearBase):
def forward(
self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
iqis: Optional[tuple] = None, **_
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
bias = self.bias if not self.skip_bias_add else None
......@@ -459,7 +431,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -473,7 +444,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
quant_config,
prefix=prefix,
return_bias=return_bias)
self.eps = eps
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
......@@ -588,7 +558,6 @@ class FusedQuantedReplicatedLinear(LinearBase):
assert len(self.kv_a_weight.shape) == 2
fused_weight = torch.cat([self.q_a_weight, self.kv_a_weight], dim=0) # TN
param.data.copy_(fused_weight)
#TODO: wjl 删掉无用的显存tensor
else:
raise ValueError(f"Unexpected weight: {source}")
......@@ -596,31 +565,17 @@ class FusedQuantedReplicatedLinear(LinearBase):
def forward(
self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor,
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
iqis: Optional[tuple] = None, **_
) -> tuple[torch.Tensor, Optional[Parameter]]:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
output_bias = self.bias if self.skip_bias_add else None
assert self.return_bias is True
if not self.return_bias:
raise RuntimeError("Not return bias. Unexpected Error.")
return output, new_residual, output_bias
return output, output_bias
else:
raise RuntimeError("Unexpected Error.")
......@@ -858,31 +813,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
def forward(
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
xqxs: Optional[tuple] = None,
iqis: Optional[tuple] = None, **_
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
if self.gather_output:
# All-gather across the partitions.
......@@ -892,7 +833,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, i_q, _scales, output_bias
return output, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
......@@ -933,13 +874,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
expect_tp_size: Optional[int] = None,
):
self.eps = eps
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
......
......@@ -70,7 +70,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
maybe_prefix)
from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from lmslim.quantize.quant_ops import lm_faster_rmsquant
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -100,21 +102,18 @@ class DeepseekV2MLP(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, i_q, _scales, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
assert iqis is not None
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi, i_q, _scales
return x
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT:
......@@ -279,9 +278,9 @@ class DeepseekV2MoE(nn.Module):
self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
num_tokens, hidden_dim = hidden_states.shape
......@@ -338,12 +337,12 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
else: # RQ
if not self.enable_expert_parallel:
i_q, i_s = None, None
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
shared_output = self.shared_experts(hidden_states, iqis=iqis)
else:
shared_output = self.shared_experts(hidden_states)
......@@ -378,7 +377,8 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......@@ -388,7 +388,7 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
else:
else: # EP
router_logits, _ = self.gate(hidden_states)
if self.use_deepep:
shared_output, final_hidden_states = self.experts(hidden_states=hidden_states,
......@@ -405,7 +405,7 @@ class DeepseekV2MoE(nn.Module):
else:
if self.run_shared_expert_singlely:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
shared_output = self.shared_experts(hidden_states, iqis=iqis)
else:
shared_output = self.shared_experts(hidden_states)
......@@ -420,9 +420,9 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += (shared_output * (1. / self.routed_scaling_factor))
else:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits,
i_q=iqis[0], i_s=iqis[1])
if shared_output is not None:
if hidden_states.dtype != torch.float16:
......@@ -441,9 +441,6 @@ class DeepseekV2MoE(nn.Module):
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
return final_hidden_states.view(num_tokens, hidden_dim)
......@@ -662,7 +659,6 @@ class DeepseekV2MLAAttention(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
......@@ -764,20 +760,22 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
pa_rms_weight: Optional[torch.Tensor] = None,
pa_residual: Optional[torch.Tensor] = None,
pa_rms_eps: Optional[float] = 1e-6,
pa_quant_dtype: Optional[torch.dtype] = torch.int8,
update_input: Optional[bool] = True
update_input: Optional[bool] = True,
iqis: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
if envs.VLLM_USE_FUSED_QA_KVA_GEMM:
if self.q_lora_rank is not None:
qc_kvc_kpe, new_residual, _bias = self.qa_kva_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
# rms_weight=rms_weight, residual=residual, update_hd=False
qc_kvc_kpe, _bias = self.qa_kva_proj(hidden_states, iqis)
q_c = qc_kvc_kpe[:, :self.q_lora_rank]
kvc_kpe = qc_kvc_kpe[:, self.q_lora_rank:]
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
......@@ -787,12 +785,12 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = kvc_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
else:
if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q_c, _ = self.q_a_proj(hidden_states, iqis=iqis)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0]
kvc_kpe = self.kv_a_proj_with_mqa(hidden_states, iqis=iqis)[0]
kv_c, k_pe = kvc_kpe.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
......@@ -835,7 +833,7 @@ class DeepseekV2MLAAttention(nn.Module):
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache)
return self.o_proj(attn_out)[0], new_residual
return self.o_proj(attn_out)[0]
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and pa_rms_weight is not None and pa_residual is not None:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
......@@ -1035,10 +1033,11 @@ class DeepseekV2DecoderLayer(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor
self.use_fused_rms_quant = envs.USE_FUSED_RMS_QUANT
self.use_fused_custom_all_reduce = envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
self._eps = config.rms_norm_eps
def forward_fused_rmsquant(
def forward_fused_RQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
......@@ -1050,22 +1049,25 @@ class DeepseekV2DecoderLayer(nn.Module):
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
i_q, i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight = self.input_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=None,
update_input=False)
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
i_q, i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight = self.input_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=residual,
update_input=False)
hidden_states = self.self_attn(positions=positions,
hidden_states = hidden_states, # get attr
iqis=(i_q, i_s))
if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
......@@ -1073,11 +1075,17 @@ class DeepseekV2DecoderLayer(nn.Module):
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
residual=residual,
)
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
_i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
epsilon=self._eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hs)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s))
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......@@ -1211,7 +1219,7 @@ class DeepseekV2DecoderLayer(nn.Module):
def choose_forward(self):
if self.use_fused_rms_quant:
return self.forward_fused_rmsquant
return self.forward_fused_RQ
elif self.use_fused_custom_all_reduce:
return self.forward_fused_CRQ
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment