Commit e55ba677 authored by zhangzbb
Browse files

Merge branch 'v0.15.1-fusion' into 'v0.15.1-dev'

feat: Support rms+quant fusion in minimax_m2 series model.

See merge request dcutoolkit/deeplearing/vllm!538
parents ca158ae9 89a8f88b
...@@ -39,7 +39,7 @@ from vllm.distributed import ( ...@@ -39,7 +39,7 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm, FusedRMSNormQuant
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
...@@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -58,6 +58,7 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm import envs
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
...@@ -229,8 +230,10 @@ class MiniMaxM2Attention(nn.Module): ...@@ -229,8 +230,10 @@ class MiniMaxM2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) # iqis: (input_quant, input_scale)
qkv, _ = self.qkv_proj(hidden_states, iqis = iqis)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk( q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous() self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
...@@ -282,27 +285,64 @@ class MiniMaxM2DecoderLayer(nn.Module): ...@@ -282,27 +285,64 @@ class MiniMaxM2DecoderLayer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if envs.USE_FUSED_RMS_QUANT:
self.input_layernorm = FusedRMSNormQuant(config.hidden_size, eps=config.rms_norm_eps)
else:
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
) )
def forward( def rms_quant_fusion_forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: torch.Tensor | None, residual: torch.Tensor | None,
) -> torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states) input_quant, input_scale, _ = self.input_layernorm(
x = hidden_states,
residual = None,
quant_dtype = torch.int8,
update_input = False
)
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) input_quant, input_scale, residual = self.input_layernorm(
x = hidden_states,
residual = residual,
quant_dtype = torch.int8,
update_input = False
)
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions = positions,
hidden_states=hidden_states, hidden_states = hidden_states,
iqis = (input_quant, input_scale)
) )
return hidden_states, residual
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> torch.Tensor:
if envs.USE_FUSED_RMS_QUANT:
hidden_states, residual = self.rms_quant_fusion_forward(
positions, hidden_states, residual
)
else:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment