Commit 89a8f88b authored by wanglong3's avatar wanglong3 Committed by zhangzbb
Browse files

feat: Support rms+quant fusion in minimax_m2 series model.

parent ca158ae9
......@@ -39,7 +39,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm, FusedRMSNormQuant
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
......@@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm import envs
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
......@@ -229,8 +230,10 @@ class MiniMaxM2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
# iqis: (input_quant, input_scale)
qkv, _ = self.qkv_proj(hidden_states, iqis = iqis)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
......@@ -282,27 +285,64 @@ class MiniMaxM2DecoderLayer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if envs.USE_FUSED_RMS_QUANT:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
def rms_quant_fusion_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor:
# Self Attention
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual = hidden_states.clone()
input_quant, input_scale, _ = self.input_layernorm(
x = hidden_states,
residual = None,
quant_dtype = torch.int8,
update_input = False
)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
input_quant, input_scale, residual = self.input_layernorm(
x = hidden_states,
residual = residual,
quant_dtype = torch.int8,
update_input = False
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
positions = positions,
hidden_states = hidden_states,
iqis = (input_quant, input_scale)
)
return hidden_states, residual
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
hidden_states, residual = self.rms_quant_fusion_forward(
positions, hidden_states, residual
)
else:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment