Commit 8cfec41a authored by zhuwenwen
Browse files

remove USE_FUSED_RMS_QUANT

parent 60b37c6b
......@@ -1984,8 +1984,6 @@ class FusedMoE(CustomOp):
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None= None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None
......@@ -2053,28 +2051,15 @@ class FusedMoE(CustomOp):
)
# Matrix multiply.
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s,
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
)
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states_combined
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
......@@ -2186,16 +2171,11 @@ def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is None
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, i_q, i_s)
else:
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits)
def moe_forward_fake(
......
......@@ -41,7 +41,6 @@ import os
from vllm.model_executor.utils import gemm_bank_conf
from lmslim.quantize.quant_ops import lm_faster_rmsquant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
logger = init_logger(__name__)
......@@ -450,52 +449,17 @@ class ReplicatedLinear(LinearBase):
def forward(
self,
input_: torch.Tensor,
rms_weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
quant_args: list | None = None,
update_hd: bool | None= True,
x: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
......@@ -660,48 +624,22 @@ class ColumnParallelLinear(LinearBase):
def forward(
self,
input_,
rms_weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
update_hd: bool | None = True
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output and self.tp_size > 1:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
......@@ -738,54 +676,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
will be treated as a "Replicated" MergedLinear.
"""
def forward(
self, input_,
rms_weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
update_hd: bool | None = True
) -> torch.Tensor | tuple[torch.Tensor, Parameter] | None:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def __init__(
self,
input_size: int,
......
......@@ -532,24 +532,15 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
x_q, x_scale = input_quant_args
elif envs.USE_FUSED_RMS_QUANT and silu_quant_args is not None:
assert len(silu_quant_args) == 2
x_zp =None
x_q, x_scale = silu_quant_args
else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
......@@ -1219,58 +1219,15 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
# Fix residual FP16 overflow
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
if hidden_states.dtype == torch.float16:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
attn_kwargs = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment