Commit 2dd0894f authored by wujl5's avatar wujl5
Browse files

rmsquant先上库接口代码

parent a1628458
...@@ -60,6 +60,7 @@ class SharedFusedMoE(FusedMoE): ...@@ -60,6 +60,7 @@ class SharedFusedMoE(FusedMoE):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
......
...@@ -18,7 +18,8 @@ from vllm.platforms import current_platform ...@@ -18,7 +18,8 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm import envs from vllm import envs
from lightop import rms_norm_dynamic_per_token_quant from lightop.op import rms_norm_dynamic_per_token_quant as ligtop_rms_norm_dynamic_per_token_quant
def rms_norm( def rms_norm(
...@@ -334,12 +335,14 @@ class FusedRMSNormQuant(nn.Module): ...@@ -334,12 +335,14 @@ class FusedRMSNormQuant(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: torch.Tensor | None = None,
quant_dtype: torch.dtype = torch.int8 quant_dtype: torch.dtype = torch.int8,
update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x, x_scales = fused_rmsquant(x, self.weight, x, x_scales = fused_rmsquant(x, self.weight,
self.variance_epsilon, self.variance_epsilon,
quant_dtype, residual) quant_dtype, residual,
update_input)
return x, x_scales, residual return x, x_scales, residual
...@@ -351,12 +354,13 @@ def fused_rmsquant_impl( ...@@ -351,12 +354,13 @@ def fused_rmsquant_impl(
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True update_input: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
output, scales = rms_norm_dynamic_per_token_quant(input, output = torch.empty_like(input, device=input.device, dtype=quant_dtype)
weight, scales = torch.empty((input.numel() // input.shape[-1], 1),
epsilon, device=input.device,
quant_dtype, dtype=torch.float32)
residual, ligtop_rms_norm_dynamic_per_token_quant(output, input, weight,
update_input) scales, epsilon,
residual, update_input)
return output, scales return output, scales
def fused_rmsquant_fake( def fused_rmsquant_fake(
...@@ -374,11 +378,15 @@ def fused_rmsquant_fake( ...@@ -374,11 +378,15 @@ def fused_rmsquant_fake(
dtype=torch.float32) dtype=torch.float32)
return output, scales return output, scales
# from torch.library import Library
# customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op( direct_register_custom_op(
op_name="rms_norm_dynamic_per_token_quant", op_name="rms_norm_dynamic_per_token_quant",
op_func=fused_rmsquant_impl, op_func=fused_rmsquant_impl,
mutates_args=[], mutates_args=[],
fake_impl=fused_rmsquant_fake, fake_impl=fused_rmsquant_fake,
# target_lib=customer_lib,
) )
def fused_rmsquant(input: torch.Tensor, def fused_rmsquant(input: torch.Tensor,
...@@ -387,12 +395,13 @@ def fused_rmsquant(input: torch.Tensor, ...@@ -387,12 +395,13 @@ def fused_rmsquant(input: torch.Tensor,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_input: Optional[bool] = True): update_input: Optional[bool] = True):
i_q, _scales = torch.ops.vllm.fused_rmsquant(input=input, from lmslim.quantize.quant_ops import lm_faster_rmsquant
weight=rms_weight, i_q, _scales = lm_faster_rmsquant(input=input,
epsilon=epsilon, rms_weight=rms_weight,
quant_dtype=quant_dtype, epsilon=epsilon,
residual=residual, quant_dtype=quant_dtype,
update_input=update_input) residual=residual,
update_input=update_input)
return i_q, _scales return i_q, _scales
......
...@@ -706,6 +706,28 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -706,6 +706,28 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
self,
input_,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def __init__( def __init__(
self, self,
......
...@@ -115,6 +115,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -115,6 +115,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None, llama_4_scaling: torch.Tensor | None = None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
q_c = None q_c = None
kv_lora = None kv_lora = None
......
...@@ -94,6 +94,8 @@ from .utils import ( ...@@ -94,6 +94,8 @@ from .utils import (
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.model_executor.layers.layernorm import FusedRMSNormQuant
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module): ...@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -218,10 +221,18 @@ class DeepseekV2MLP(nn.Module): ...@@ -218,10 +221,18 @@ class DeepseekV2MLP(nn.Module):
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self,
gate_up, _ = self.gate_up_proj(x) x,
x = self.act_fn(gate_up) *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
x, _ = self.down_proj(x) ):
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x return x
...@@ -334,7 +345,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -334,7 +345,9 @@ class DeepseekV2MoE(nn.Module):
else None, else None,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -385,6 +398,60 @@ class DeepseekV2MoE(nn.Module): ...@@ -385,6 +398,60 @@ class DeepseekV2MoE(nn.Module):
) )
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
def forward_RQ(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# Chunk the hidden states so they aren't replicated across TP ranks.
# This avoids duplicate computation in self.experts.
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits, iqis=iqis
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None:
assert shared_output is None
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...@@ -528,6 +595,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -528,6 +595,7 @@ class DeepseekV2Attention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None, llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0] q = self.q_a_proj(hidden_states)[0]
...@@ -904,8 +972,9 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -904,8 +972,9 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None, llama_4_scaling: torch.Tensor | None,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
return self.mla_attn(positions, hidden_states, llama_4_scaling) return self.mla_attn(positions, hidden_states, llama_4_scaling, iqis=iqis)
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
...@@ -986,13 +1055,111 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -986,13 +1055,111 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( if not envs.USE_FUSED_RMS_QUANT:
config.hidden_size, eps=config.rms_norm_eps self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
) self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
else:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = FusedRMSNormQuant(
config.hidden_size, eps=config.rms_norm_eps
)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
def forward( def forward_RQ(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None,
quant_dtype=torch.int8,
update_input=False)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=None,
# update_input=False)
residual_fix_overflow = True
else:
# hidden_states, residual = self.input_layernorm(hidden_states, residual)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=False)
i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=False)
attn_kwargs = {
"positions": positions,
"hidden_states": hidden_states,
"iqis": (i_q, i_s)
}
if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling
hidden_states = self.self_attn(**attn_kwargs)
if (
not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16
):
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1.0 / self.routed_scaling_factor
if self.layer_idx == 0:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1.0 / self.routed_scaling_factor
# Fully Connected
# hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
# _i_q, _i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight=self.post_attention_layernorm.weight.data,
# epsilon=self._eps,
# quant_dtype=torch.int8,
# residual=residual,
# update_input=update_hs)
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual,
quant_dtype=torch.int8,
update_input=update_hs)
new_resi = residual
hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s))
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1.0 / self.routed_scaling_factor
return hidden_states, new_resi
def forward_default(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1045,6 +1212,25 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1045,6 +1212,25 @@ class DeepseekV2DecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
def choose_forward(self):
if envs.USE_FUSED_RMS_QUANT:
return self.forward_RQ
else:
return self.forward_default
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
forward_func = self.choose_forward()
return forward_func(positions=positions,
hidden_states=hidden_states,
residual=residual,
llama_4_scaling=llama_4_scaling)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
......
...@@ -75,7 +75,7 @@ from .utils import ( ...@@ -75,7 +75,7 @@ from .utils import (
) )
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
if envs.VLLM_USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
from lightop import rms_norm_dynamic_per_token_quant from lightop import rms_norm_dynamic_per_token_quant
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -414,7 +414,7 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -414,7 +414,7 @@ class Glm4MoeDecoderLayer(nn.Module):
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
) )
else: else:
if not envs.VLLM_USE_FUSED_RMS_QUANT: if not envs.USE_FUSED_RMS_QUANT:
self.mlp = Glm4MoeMLP( self.mlp = Glm4MoeMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
...@@ -431,7 +431,7 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -431,7 +431,7 @@ class Glm4MoeDecoderLayer(nn.Module):
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
if not envs.VLLM_USE_FUSED_RMS_QUANT: if not envs.USE_FUSED_RMS_QUANT:
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
...@@ -454,7 +454,7 @@ class Glm4MoeDecoderLayer(nn.Module): ...@@ -454,7 +454,7 @@ class Glm4MoeDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
if not envs.VLLM_USE_FUSED_RMS_QUANT: if not envs.USE_FUSED_RMS_QUANT:
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment