Unverified Commit a53046b1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Model] Support quantization of PixtralHFTransformer for PixtralHF (#9921)


Signed-off-by: default avatarmgoin <michael@neuralmagic.com>
parent 731aec5b
...@@ -299,3 +299,33 @@ def get_act_fn( ...@@ -299,3 +299,33 @@ def get_act_fn(
return ScaledActivation(act_fn, intermediate_size, input_is_parallel, return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype) params_dtype)
return act_fn return act_fn
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
})
def get_act_and_mul_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
...@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata ...@@ -19,8 +19,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module): ...@@ -798,20 +801,24 @@ class PixtralHFMLP(nn.Module):
super().__init__() super().__init__()
assert config.intermediate_size is not None assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this self.gate_up_proj = MergedColumnParallelLinear(
self.gate_proj = nn.Linear(config.hidden_size, input_size=config.hidden_size,
config.intermediate_size, output_sizes=[config.intermediate_size] * 2,
bias=False) bias=False,
self.up_proj = nn.Linear(config.hidden_size, quant_config=quant_config,
config.intermediate_size, prefix=f"{prefix}.gate_up_proj")
bias=False) self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
self.down_proj = nn.Linear(config.intermediate_size, output_size=config.hidden_size,
config.hidden_size, bias=False,
bias=False) quant_config=quant_config,
self.act = get_act_fn(config.hidden_act) prefix=f"{prefix}.down_proj")
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) gate_up, _ = self.gate_up_proj(x)
x = self.act_and_mul(gate_up)
x, _ = self.down_proj(x)
return x
class PixtralHFAttention(nn.Module): class PixtralHFAttention(nn.Module):
...@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module): ...@@ -830,21 +837,21 @@ class PixtralHFAttention(nn.Module):
self.n_heads = config.num_attention_heads self.n_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
# TODO: Use quant_config and prefix after optimizing this head_size=self.head_dim,
self.q_proj = nn.Linear(config.hidden_size, total_num_heads=self.n_heads,
config.hidden_size, bias=False,
bias=False) quant_config=quant_config,
self.k_proj = nn.Linear(config.hidden_size, prefix=f"{prefix}.qkv_proj",
config.hidden_size, )
bias=False) self.o_proj = RowParallelLinear(
self.v_proj = nn.Linear(config.hidden_size, input_size=config.hidden_size,
config.hidden_size, output_size=config.hidden_size,
bias=False) bias=False,
self.o_proj = nn.Linear(config.hidden_size, quant_config=quant_config,
config.hidden_size, prefix=f"{prefix}.o_proj",
bias=False) )
def forward( def forward(
self, self,
...@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module): ...@@ -854,13 +861,13 @@ class PixtralHFAttention(nn.Module):
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size() batch, patches, _ = hidden_states.size()
q = self.q_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
k = self.k_proj(hidden_states) q, k, v = qkv_states.chunk(3, dim=-1)
v = self.v_proj(hidden_states)
# Transpose q and k to apply HF's Rotary Position Embedding # Transpose q and k to apply HF's Rotary Position Embedding
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch, patches, self.n_heads, self.head_dim)
cos, sin = position_embeddings cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0) q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
...@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module): ...@@ -868,22 +875,21 @@ class PixtralHFAttention(nn.Module):
# Transpose q and k back for attention # Transpose q and k back for attention
q = q.transpose(1, 2).contiguous() q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
out = xops.memory_efficient_attention(q, out = xops.memory_efficient_attention(q,
k, k,
v, v,
attn_bias=attention_mask) attn_bias=attention_mask)
else: else:
v = v.reshape(batch, patches, self.n_heads, v = v.transpose(1, 2)
self.head_dim).transpose(1, 2)
out = nn.functional.scaled_dot_product_attention( out = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask) q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2) out = out.transpose(1, 2)
out = out.reshape(batch, patches, self.n_heads * self.head_dim) out = out.view(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out)
return self.o_proj(out) return attn_output, None
class PixtralHFTransformerBlock(nn.Module): class PixtralHFTransformerBlock(nn.Module):
...@@ -912,9 +918,9 @@ class PixtralHFTransformerBlock(nn.Module): ...@@ -912,9 +918,9 @@ class PixtralHFTransformerBlock(nn.Module):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
position_embeddings: torch.Tensor, position_embeddings: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states), r, _ = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask, attention_mask=attention_mask,
position_embeddings=position_embeddings) position_embeddings=position_embeddings)
h = hidden_states + r h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h)) r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r out = h + r
...@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1053,10 +1059,24 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [] stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
layer_count = len(self.transformer.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
# omit layers when num_hidden_layers_override is set
if name.startswith("transformer.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment