"tools/vscode:/vscode.git/clone" did not exist on "5e30e9b9a9b0dc91927f2ec4ffbd92c4af3825c0"
Commit 8ff0c0d2 authored by wujl5
Browse files

rmsquant融合的add和rms功能可以使用了,quant的tensor输入到模型后输出错误, moe的gate-up层未量化

parent 2dd0894f
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import Literal, cast, get_args, overload from typing import Literal, cast, get_args, overload, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp): ...@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
...@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp): ...@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp):
) )
else: else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared( shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, encode_layer_name() hidden_states, router_logits, encode_layer_name(),
i_q=i_q,
i_s=i_s
) )
return ( return (
reduce_output(shared_output)[..., :og_hidden_states], reduce_output(shared_output)[..., :og_hidden_states],
...@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp): ...@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(hidden_states, router_logits) return self.forward_native(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def forward_impl_chunked( def forward_impl_chunked(
self, self,
...@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp): ...@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.quant_method is not None assert self.quant_method is not None
...@@ -1904,6 +1912,7 @@ class FusedMoE(CustomOp): ...@@ -1904,6 +1912,7 @@ class FusedMoE(CustomOp):
# parallel execution of shared experts with the FusedMoE via # parallel execution of shared experts with the FusedMoE via
# separate cuda stream) # separate cuda stream)
if self.gate is not None: if self.gate is not None:
print("YYYY: unsupported using gate in FusedMOE.forward_impl") # true
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if use_chunked_impl: if use_chunked_impl:
...@@ -2005,12 +2014,24 @@ class FusedMoE(CustomOp): ...@@ -2005,12 +2014,24 @@ class FusedMoE(CustomOp):
if self.capture is not None: if self.capture is not None:
self.capture(topk_ids) self.capture(topk_ids)
if envs.USE_FUSED_RMS_QUANT:
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=x, # The type signture of this is wrong due to the hack. x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
i_q=i_q,
i_s=i_s
)
else:
final_hidden_states = self.quant_method.apply(
layer=self,
x=x, # The type signture of this is wrong due to the hack.
topk_weights=topk_weights,
topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe
) )
if has_separate_shared_experts: if has_separate_shared_experts:
...@@ -2133,16 +2154,20 @@ def moe_forward( ...@@ -2133,16 +2154,20 @@ def moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is None assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
def moe_forward_fake( def moe_forward_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -2160,9 +2185,20 @@ def moe_forward_shared( ...@@ -2160,9 +2185,20 @@ def moe_forward_shared(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
import traceback
import torch.distributed as dist
print("=======")
# if dist.get_rank() == 0:
# traceback.print_stack()
print("=======")
self = get_layer_from_name(layer_name) self = get_layer_from_name(layer_name)
assert self.shared_experts is not None assert self.shared_experts is not None
if envs.USE_FUSED_RMS_QUANT:
return self.forward_impl(hidden_states, router_logits, i_q=i_q, i_s=i_s)
else:
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
...@@ -2170,6 +2206,8 @@ def moe_forward_shared_fake( ...@@ -2170,6 +2206,8 @@ def moe_forward_shared_fake(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, layer_name: str,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states) shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states) fused_out = torch.empty_like(hidden_states)
......
...@@ -62,9 +62,13 @@ class SharedFusedMoE(FusedMoE): ...@@ -62,9 +62,13 @@ class SharedFusedMoE(FusedMoE):
router_logits: torch.Tensor, router_logits: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if self._shared_experts is not None: print("YYYY: 处理下游专家输入, self.use_overlapped is :", self.use_overlapped)
if not self.use_overlapped: if not self.use_overlapped:
if self._shared_experts is not None: if self._shared_experts is not None:
shared_out = self._shared_experts(hidden_states) shared_out = self._shared_experts(hidden_states,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None)
# Reduce shared expert outputs if necessary, since the MLP # Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False. # should have been created with reduce_results=False.
...@@ -77,14 +81,20 @@ class SharedFusedMoE(FusedMoE): ...@@ -77,14 +81,20 @@ class SharedFusedMoE(FusedMoE):
else: else:
shared_out = None shared_out = None
print("YYYY: i suppose not this branch!!!")
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
) )
else: else:
shared_out, fused_out = super().forward( shared_out, fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
i_q=iqis[0] if iqis is not None else None,
i_s=iqis[1] if iqis is not None else None,
) )
# ensure early TP reduction of shared expert outputs when required # ensure early TP reduction of shared expert outputs when required
if ( if (
......
...@@ -339,11 +339,11 @@ class FusedRMSNormQuant(nn.Module): ...@@ -339,11 +339,11 @@ class FusedRMSNormQuant(nn.Module):
update_input: Optional[bool] = True update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
x, x_scales = fused_rmsquant(x, self.weight, i_q, i_s = fused_rmsquant(x, self.weight,
self.variance_epsilon, self.variance_epsilon,
quant_dtype, residual, quant_dtype, residual,
update_input) update_input)
return x, x_scales, residual return i_q, i_s, residual
def fused_rmsquant_impl( def fused_rmsquant_impl(
......
...@@ -716,6 +716,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -716,6 +716,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
print("YYYYY: mlp.gate_up self.quant_method.apply: ", self.quant_method.apply)
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args=iqis)
else:
output_parallel = self.quant_method.apply(self, input_, bias) output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1: if self.gather_output and self.tp_size > 1:
......
...@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention ...@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm import envs
@dataclass @dataclass
...@@ -130,6 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -130,6 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert self.q_b_proj is not None, ( assert self.q_b_proj is not None, (
"q_b_proj is required when q_lora_rank is not None" "q_b_proj is required when q_lora_rank is not None"
) )
# print("YYYY: self.fused_qkv_a_proj is:", self.fused_qkv_a_proj) # MergedColumnParallelLinear(in_features=7168, output_features=2112, bias=False, tp_size=1, gather_output=False)
if envs.USE_FUSED_RMS_QUANT and iqis is not None:
qkv_lora = self.fused_qkv_a_proj(hidden_states, iqis=iqis)[0]
else:
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split( q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
...@@ -138,6 +143,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -138,6 +143,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
q_c = self.q_a_layernorm(q_c) q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0] q = self.q_b_proj(q_c)[0]
else: else:
# print("YYYY: self.q_lora_rank is None. unsupported for now!!!")
assert self.kv_a_proj_with_mqa is not None, ( assert self.kv_a_proj_with_mqa is not None, (
"kv_a_proj_with_mqa is required when q_lora_rank is None" "kv_a_proj_with_mqa is required when q_lora_rank is None"
) )
......
...@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
i_q=i_q,
i_s=i_s
) )
def select_gemm_impl( def select_gemm_impl(
......
...@@ -345,7 +345,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -345,7 +345,7 @@ class DeepseekV2MoE(nn.Module):
else None, else None,
) )
def forward(self, hidden_states: torch.Tensor, def forward_default(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
...@@ -413,12 +413,15 @@ class DeepseekV2MoE(nn.Module): ...@@ -413,12 +413,15 @@ class DeepseekV2MoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states) hidden_states = sequence_parallel_chunk(hidden_states)
if self.experts.is_internal_router: if self.experts.is_internal_router:
# print("YYYY: self.experts.is_internal_router is True.") # True
# In this case, the gate/router runs inside the FusedMoE class # In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis hidden_states=hidden_states, router_logits=hidden_states, iqis=iqis
) )
else: else: # NO
print("YYYY: self.experts.is_internal_router is False.")
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
print("YYYY: RQ MOE.gate quant: self.gate.quant_method.apply: ", self.gate.quant_method.apply)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts( fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits, iqis=iqis hidden_states=hidden_states, router_logits=router_logits, iqis=iqis
...@@ -453,6 +456,18 @@ class DeepseekV2MoE(nn.Module): ...@@ -453,6 +456,18 @@ class DeepseekV2MoE(nn.Module):
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
def choose_forward(self):
    """Pick the forward implementation for this MoE layer.

    Returns the bound ``forward_RQ`` method when
    ``envs.USE_FUSED_RMS_QUANT`` is truthy, and the bound
    ``forward_default`` method otherwise. The flag is read on every call.
    """
    use_fused_path = envs.USE_FUSED_RMS_QUANT
    return self.forward_RQ if use_fused_path else self.forward_default
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
forward_func = self.choose_forward()
return forward_func(hidden_states, iqis=iqis)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
import math import math
...@@ -1078,14 +1093,19 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1078,14 +1093,19 @@ class DeepseekV2DecoderLayer(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
# Fix residual FP16 overflow # Fix residual FP16 overflow
# print("YYYYY: forward_RQ is called")
residual_fix_overflow = False residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None: if residual is None:
residual = hidden_states.clone() residual = hidden_states.clone()
i_q, i_s, _ = self.input_layernorm(x=hidden_states, i_q, i_s, _ = self.input_layernorm(x=hidden_states,
residual=None, residual=None,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=False) # update_input=False # wjl, del
update_input=True
)
# print("YYYY: i_q:", i_q.flatten()[:5], i_q.shape)
# print("YYYY: i_s:", i_s.flatten()[:5], i_s.shape)
# i_q, i_s = lm_faster_rmsquant(input=hidden_states, # i_q, i_s = lm_faster_rmsquant(input=hidden_states,
# rms_weight = self.input_layernorm.weight.data, # rms_weight = self.input_layernorm.weight.data,
# epsilon=self._eps, # epsilon=self._eps,
...@@ -1101,17 +1121,20 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1101,17 +1121,20 @@ class DeepseekV2DecoderLayer(nn.Module):
# quant_dtype=torch.int8, # quant_dtype=torch.int8,
# residual=residual, # residual=residual,
# update_input=False) # update_input=False)
# print("YYYY: input rms residual bf", residual.flatten()[:5])
i_q, i_s, residual = self.input_layernorm(x=hidden_states, i_q, i_s, residual = self.input_layernorm(x=hidden_states,
residual=residual, residual=residual,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=False) # update_input=False, # wjl, del
update_input=True,
)
# print("YYYY:input rms residual af", residual.flatten()[:5])
attn_kwargs = { attn_kwargs = {
"positions": positions, "positions": positions,
"hidden_states": hidden_states, "hidden_states": hidden_states,
"iqis": (i_q, i_s) # "iqis": (i_q, i_s) # wjl, del
} }
if not self.use_mha: if not self.use_mha:
attn_kwargs["llama_4_scaling"] = llama_4_scaling attn_kwargs["llama_4_scaling"] = llama_4_scaling
...@@ -1121,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1121,6 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
not isinstance(self.self_attn, DeepseekAttention) not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16 and hidden_states.dtype == torch.float16
): ):
print("YYYYY: FP16 overflow fix is applied")
# Fix FP16 overflow # Fix FP16 overflow
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
...@@ -1139,14 +1163,18 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1139,14 +1163,18 @@ class DeepseekV2DecoderLayer(nn.Module):
# quant_dtype=torch.int8, # quant_dtype=torch.int8,
# residual=residual, # residual=residual,
# update_input=update_hs) # update_input=update_hs)
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
residual=residual, residual=residual,
quant_dtype=torch.int8, quant_dtype=torch.int8,
update_input=update_hs) update_input=update_hs, # wjl, del
# update_input=True,
)
new_resi = residual new_resi = residual
hidden_states = self.mlp(hidden_states, hidden_states = self.mlp(hidden_states,
iqis=(_i_q, _i_s)) # iqis=(_i_q, _i_s) # wjl, del
)
if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow # Fix FP16 overflow
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment