Unverified Commit e392d858 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Core] Refactor `QKVCrossParallelLinear` implementation to support BNB 4-bit quantization (#14545)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 77a318bd
...@@ -17,6 +17,7 @@ from vllm.sequence import SampleLogprobs ...@@ -17,6 +17,7 @@ from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets) _ImageAssets)
from ....quantization.utils import is_quant_method_supported
from ....utils import large_gpu_test from ....utils import large_gpu_test
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
...@@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, ...@@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
) )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_bnb_regression(
image_assets: _ImageAssets,
model: str,
dtype: str,
max_tokens: int,
):
stop_sign = image_assets[0].pil_image
prompts = [
{
"prompt": "<|begin_of_text|>The content of the image <|image|> is",
"multi_modal_data": {
"image": stop_sign
},
},
{
"prompt":
"The color of the sky is blue but sometimes it can also be",
},
]
# Test regression about QKVCrossParallelLinear
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
enforce_eager=True,
quantization="bitsandbytes",
load_format="bitsandbytes",
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_tokens,
)
outputs = llm.generate(prompts, sampling_params)
assert outputs
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@pytest.mark.core_model @pytest.mark.core_model
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
......
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, Union from typing import Any, Literal, Optional, Union
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
...@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): ...@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight return param[shard_id], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
"""
Separate the BitsAndBytes 4-bit shard.
For example, given bnb weight attributes as below:
{
'bnb_shard_offsets': array([0, 4, 8, 16]),
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
}
The function will return:
{
'bnb_shard_offsets': array([0, 4]),
'bnb_quant_state': {0: ...},
}
and
{
'bnb_shard_offsets': array([0, 4, 12]),
'bnb_quant_state': {0: ..., 1: ...},
}
"""
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
offset_l = shard_offsets[:2]
offset_r = shard_offsets[1:] - shard_offsets[1]
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
quant_state_r = {
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
for i in range(1,
len(shard_offsets) - 1)
}
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
return left, right
class LinearMethodBase(QuantizeMethodBase): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
...@@ -1229,7 +1267,24 @@ class RowParallelLinear(LinearBase): ...@@ -1229,7 +1267,24 @@ class RowParallelLinear(LinearBase):
return s return s
class QKVCrossParallelLinear(torch.nn.Module): class QKVCrossParallelLinear(LinearBase):
"""Linear layers for efficient cross-attention's QKV transformation.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self, def __init__(self,
hidden_size: int, hidden_size: int,
...@@ -1241,12 +1296,28 @@ class QKVCrossParallelLinear(torch.nn.Module): ...@@ -1241,12 +1296,28 @@ class QKVCrossParallelLinear(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
super().__init__() # input_size and output_size are not used, just for alignment
input_size = hidden_size
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
super().__init__(input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
self.quant_config = quant_config
# Empty placeholders for loading as a single module. # Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter() placeholder_size = 0
set_weight_attrs(self.weight, { assert self.quant_method is not None
"weight_loader": self.weight_loader_weight, self.quant_method.create_weights(self,
}) placeholder_size, [placeholder_size],
placeholder_size,
placeholder_size,
self.params_dtype,
weight_loader=self.weight_loader)
# Use a dictionary to avoid submodules parameters auto-registration: # Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module. # drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict() self.proj = dict()
...@@ -1276,18 +1347,94 @@ class QKVCrossParallelLinear(torch.nn.Module): ...@@ -1276,18 +1347,94 @@ class QKVCrossParallelLinear(torch.nn.Module):
if bias: if bias:
self.bias = torch.nn.Parameter() self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, { set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias, "output_dim": 0,
"weight_loader": self.weight_loader,
}) })
else:
self.bias = None
@property @property
def q_proj_decoder(self): def q_proj_decoder(self) -> ColumnParallelLinear:
return self.proj["q_proj_decoder"] layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
return layer
@property @property
def kv_proj_encoder(self): def kv_proj_encoder(self) -> QKVParallelLinear:
return self.proj["kv_proj_encoder"] layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
return layer
def sync_weight_attrs(
self,
src_param: nn.Parameter,
tgt_param: nn.Parameter,
mode: Literal["q_proj_decoder", "kv_proj_encoder"],
):
missing_attrs_dict = {
k: getattr(src_param, k)
for k in (set(src_param.__dict__.keys()) -
set(tgt_param.__dict__.keys()))
}
# TODO(Isotr0py): handle bitsandbytes 8bit
use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
False)
if (missing_attrs_dict and use_bitsandbytes_4bit):
q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
missing_attrs_dict)
if mode == "q_proj_decoder":
set_weight_attrs(tgt_param, q_proj_attrs)
elif mode == "kv_proj_encoder":
set_weight_attrs(tgt_param, kv_proj_attrs)
else:
set_weight_attrs(tgt_param, missing_attrs_dict)
def _is_same_param(
self,
src_param: torch.nn.Parameter,
map_param: torch.nn.Parameter,
) -> bool:
"""Check if two parameters are exactly pointing to same things."""
# ignore weight_loader because it's always different
key_to_ignore = ["weight_loader", "_weight_loader"]
has_same_type_name = type(src_param) is type(map_param)
src_param_attrs = {
k: v
for k, v in src_param.__dict__.items() if k not in key_to_ignore
}
map_param_attrs = {
k: v
for k, v in map_param.__dict__.items() if k not in key_to_ignore
}
has_same_attrs = src_param_attrs == map_param_attrs
return has_same_type_name and has_same_attrs
def forward(self, decoder_hidden_states, encoder_hidden_states): def select_proj_params(
self,
layer: nn.Module,
param: nn.Parameter,
) -> nn.Parameter:
"""
Given the placeholder param,
return the corresponding param in the proj layers.
"""
target_param_list = [
v for _, v in layer.named_parameters()
if self._is_same_param(param, v)
]
assert len(target_param_list) == 1
target_param = target_param_list[0]
return target_param
def forward( # type: ignore[override]
self,
decoder_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
q, _ = self.q_proj_decoder(decoder_hidden_states) q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None: if encoder_hidden_states is None:
# Encoder KV already cached. # Encoder KV already cached.
...@@ -1300,25 +1447,21 @@ class QKVCrossParallelLinear(torch.nn.Module): ...@@ -1300,25 +1447,21 @@ class QKVCrossParallelLinear(torch.nn.Module):
k, v = kv_enc.split(self.kv_size, dim=-1) k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v return q, k, v
def weight_loader_weight(self, def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
def weight_loader_bias(self,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \ layer = (self.q_proj_decoder
else self.kv_proj_encoder.bias if loaded_shard_id == "q" else self.kv_proj_encoder)
param.weight_loader( target_param = self.select_proj_params(layer, param)
param, shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
loaded_weight) if loaded_shard_id == "q" else param.weight_loader( layer.weight_loader(target_param, loaded_weight, *shard_id_args)
param, loaded_weight, loaded_shard_id)
\ No newline at end of file def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
s += f", kv_size={self.kv_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += ", gather_output=False"
return s
...@@ -43,6 +43,7 @@ from vllm.forward_context import get_forward_context ...@@ -43,6 +43,7 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -813,20 +814,11 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -813,20 +814,11 @@ class MllamaTextCrossAttention(nn.Module):
self.q_local_size = self.num_local_heads * self.head_dim self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim self.kv_local_size = self.num_local_key_value_heads * self.head_dim
# TODO(Isotr0py): Use QKVCrossParallelLinear when it supports self.qkv_proj = QKVCrossParallelLinear(
# quantization
self.q_proj = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
self.hidden_size, self.hidden_size,
self.head_dim, self.head_dim,
total_num_heads=0, self.num_heads,
total_num_kv_heads=self.num_key_value_heads, self.num_key_value_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
...@@ -862,15 +854,11 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -862,15 +854,11 @@ class MllamaTextCrossAttention(nn.Module):
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
q, _ = self.q_proj(hidden_states) q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
if cross_attention_states is not None: if cross_attention_states is not None:
kv, _ = self.kv_proj(cross_attention_states)
k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim) k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim) v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k) k = self.k_norm(k)
else:
k = v = None
q = q.view(-1, self.num_local_heads, self.head_dim) q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q) q = self.q_norm(q)
...@@ -1161,13 +1149,8 @@ class MllamaForCausalLM(nn.Module): ...@@ -1161,13 +1149,8 @@ class MllamaForCausalLM(nn.Module):
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only): SupportsV0Only):
packed_modules_mapping = { packed_modules_mapping = {
"self_attn.qkv_proj": [ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"self_attn.q_proj", "gate_up_proj": ["gate_proj", "up_proj"]
"self_attn.k_proj",
"self_attn.v_proj",
],
"cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
...@@ -1437,11 +1420,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1437,11 +1420,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
(".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
(".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment