Unverified Commit 96354d6a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Add base class for LoRA-supported models (#5018)

parent d12af207
...@@ -4,6 +4,9 @@ Using LoRA adapters ...@@ -4,6 +4,9 @@ Using LoRA adapters
=================== ===================
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model. This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
them locally with them locally with
......
...@@ -2,6 +2,7 @@ from typing import List, Optional ...@@ -2,6 +2,7 @@ from typing import List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
import torch import torch
import torch.types
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -64,7 +65,7 @@ class LoRALayerWeights: ...@@ -64,7 +65,7 @@ class LoRALayerWeights:
output_dim: int, output_dim: int,
rank: int, rank: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available() pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank], lora_a = torch.zeros([input_dim, rank],
......
...@@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ...@@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor, from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -363,7 +364,7 @@ class LoRAModelManager: ...@@ -363,7 +364,7 @@ class LoRAModelManager:
def __init__( def __init__(
self, self,
model: nn.Module, model: SupportsLoRA,
max_num_seqs: int, max_num_seqs: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
vocab_size: int, vocab_size: int,
...@@ -411,7 +412,7 @@ class LoRAModelManager: ...@@ -411,7 +412,7 @@ class LoRAModelManager:
# embeddings_indices # embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4 self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model self.model = model
if hasattr(self.model, "supported_lora_modules"): if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy( self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules) self.model.supported_lora_modules)
...@@ -428,7 +429,6 @@ class LoRAModelManager: ...@@ -428,7 +429,6 @@ class LoRAModelManager:
self._active_loras: Dict[int, None] = {} self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules() self._create_lora_modules()
self.model.lora_manager = self
@property @property
def capacity(self) -> int: def capacity(self) -> int:
......
...@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator) pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu from vllm.utils import is_tpu
...@@ -64,12 +65,15 @@ def _get_quantization_config( ...@@ -64,12 +65,15 @@ def _get_quantization_config(
def _get_model_initialization_kwargs( def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], model_class: Type[nn.Module],
vision_language_config: Optional[VisionLanguageConfig] lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Get extra kwargs for model initialization.""" """Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {} extra_kwargs: Dict[str, Any] = {}
if hasattr(model_class, "supported_lora_modules"):
if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config extra_kwargs["lora_config"] = lora_config
elif lora_config: elif lora_config:
raise ValueError( raise ValueError(
...@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs( ...@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may " "but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, " "be added in the future. If this is important to you, "
"please open an issue on github.") "please open an issue on github.")
elif issubclass(model_class, VisionLanguageModelBase):
if vision_language_config is None: if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision " raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint " "related configurations through LLM entrypoint "
"or engine arguments.") "or engine arguments.")
extra_kwargs["vision_language_config"] = vision_language_config extra_kwargs["vlm_config"] = vlm_config
return extra_kwargs return extra_kwargs
......
...@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
...@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module): ...@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module):
return hidden_states return hidden_states
class BaiChuanBaseForCausalLM(nn.Module): class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"W_pack": ["W_pack"], "W_pack": ["W_pack"],
"gate_up_proj": [ "gate_up_proj": [
...@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config, self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config) quant_config)
......
...@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA
class GLMAttention(nn.Module): class GLMAttention(nn.Module):
...@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module): ...@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module):
return hidden_states return hidden_states
class ChatGLMForCausalLM(nn.Module): class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
...@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length",
8192) 8192)
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
from transformers import PretrainedConfig from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: Optional[PretrainedConfig] = None, config: LlamaConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
......
...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -288,7 +290,9 @@ class GemmaModel(nn.Module): ...@@ -288,7 +290,9 @@ class GemmaModel(nn.Module):
return hidden_states return hidden_states
class GemmaForCausalLM(nn.Module): class GemmaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module): ...@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel(config, cache_config, quant_config) self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
......
...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
...@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module): ...@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module):
return hidden_states return hidden_states
class GPTBigCodeForCausalLM(nn.Module): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {"c_attn": ["c_attn"]} packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
...@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module): ...@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config, self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config) lora_config)
......
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing_extensions import TypeGuard
from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""
supports_vision: ClassVar[Literal[True]]
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
...
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
...
@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
...
def supports_vision(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
return isinstance(model, SupportsVision)
@runtime_checkable
class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA."""
supports_lora: ClassVar[Literal[True]]
packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
...
@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
...
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs
if not hasattr(model, attr))
if getattr(model, "supports_lora", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_lora=True`, "
"but is missing LoRA-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all LoRA-specific attributes, "
"but does not set `supports_lora=True`.", model)
return result
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
return isinstance(model, SupportsLoRA)
...@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -296,7 +298,9 @@ class LlamaModel(nn.Module): ...@@ -296,7 +298,9 @@ class LlamaModel(nn.Module):
return hidden_states return hidden_states
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module): ...@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.model = LlamaModel(config, self.model = LlamaModel(config,
cache_config, cache_config,
quant_config, quant_config,
......
...@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .vlm_base import VisionLanguageModelBase from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head", "language_model.lm_head": "lm_head",
...@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] ...@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_image_feature_input() @MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input() @MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(VisionLanguageModelBase): class LlavaForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
vision_language_config: VisionLanguageConfig, vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config) super().__init__()
self.config = config self.config = config
self.vlm_config = vlm_config
if self.vision_language_config.image_input_type == ( if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES): VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config) self.vision_tower = CLIPVisionModel(config.vision_config)
else: else:
...@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
self.sampler = Sampler() self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list( if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
self.vision_language_config.image_input_shape[1:]):
raise ValueError( raise ValueError(
f"The expected image tensor shape is batch dimension plus " f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. " f"{self.vlm_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. " f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your " f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with " f"supplied image input is consistent with "
...@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None) image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES: if expected_input_type == ImageInputType.PIXEL_VALUES:
...@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id) self.vlm_config.image_token_id)
input_ids = None input_ids = None
else: else:
......
...@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData ...@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput, SequenceData from vllm.sequence import SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .vlm_base import VisionLanguageModelBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -106,19 +106,21 @@ def _image_pixel_processor( ...@@ -106,19 +106,21 @@ def _image_pixel_processor(
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
class LlavaNextForConditionalGeneration(VisionLanguageModelBase): class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
vision_language_config: VisionLanguageConfig, vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config) super().__init__()
# Update the type annotation from that of its superclass
self.config = config self.config = config
self.vlm_config = vlm_config
if self.vision_language_config.image_input_type == ( if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES): VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config=config.vision_config) self.vision_tower = CLIPVisionModel(config=config.vision_config)
else: else:
...@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase): ...@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vision_language_config.image_input_shape _, num_channels, _, _ = self.vlm_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config # Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor # since the image is resized by the HuggingFace preprocessor
...@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase): ...@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
image_sizes = kwargs.pop("image_sizes", None) image_sizes = kwargs.pop("image_sizes", None)
image_features = kwargs.pop("image_features", None) image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES: if expected_input_type == ImageInputType.PIXEL_VALUES:
...@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase): ...@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id) self.vlm_config.image_token_id)
input_ids = None input_ids = None
else: else:
......
...@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple ...@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class MiniCPMMoE(nn.Module): class MiniCPMMoE(nn.Module):
"""A tensor-parallel MoE implementation that shards each expert """A tensor-parallel MoE implementation that shards each expert
...@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module): ...@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module):
return hidden_states return hidden_states
class MiniCPMForCausalLM(nn.Module): class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module): ...@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config self.quant_config = quant_config
self.model = MiniCPMModel(config, self.model = MiniCPMModel(config,
......
...@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert """A tensor-parallel MoE implementation for Mixtral that shards each expert
...@@ -472,7 +474,9 @@ class MixtralModel(nn.Module): ...@@ -472,7 +474,9 @@ class MixtralModel(nn.Module):
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module): class MixtralForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
...@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module): ...@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config, self.model = MixtralModel(config,
cache_config, cache_config,
quant_config, quant_config,
......
...@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple ...@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class PhiAttention(nn.Module): class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
...@@ -131,7 +133,7 @@ class PhiAttention(nn.Module): ...@@ -131,7 +133,7 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module): class PhiMLP(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PhiConfig,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
...@@ -160,7 +162,7 @@ class PhiMLP(nn.Module): ...@@ -160,7 +162,7 @@ class PhiMLP(nn.Module):
class PhiLayer(nn.Module): class PhiLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
...@@ -192,7 +194,7 @@ class PhiLayer(nn.Module): ...@@ -192,7 +194,7 @@ class PhiLayer(nn.Module):
class PhiModel(nn.Module): class PhiModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
...@@ -229,7 +231,9 @@ class PhiModel(nn.Module): ...@@ -229,7 +231,9 @@ class PhiModel(nn.Module):
return hidden_states return hidden_states
class PhiForCausalLM(nn.Module): class PhiForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module): ...@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = PhiModel(config, cache_config, quant_config) self.model = PhiModel(config, cache_config, quant_config)
......
...@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
...@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module): ...@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module):
return hidden_states return hidden_states
class Qwen2ForCausalLM(nn.Module): class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
del lora_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
...@@ -307,7 +310,10 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -307,7 +310,10 @@ class Qwen2ForCausalLM(nn.Module):
)) ))
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config) self.model = Qwen2Model(config, cache_config, quant_config)
......
from torch import nn
from vllm.config import VisionLanguageConfig
class VisionLanguageModelBase(nn.Module):
"""Base class for all vision language models (VLMs)."""
def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
super().__init__()
self.vision_language_config = vision_language_config
...@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -266,7 +268,9 @@ class XverseModel(nn.Module): ...@@ -266,7 +268,9 @@ class XverseModel(nn.Module):
return hidden_states return hidden_states
class XverseForCausalLM(nn.Module): class XverseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -299,10 +303,13 @@ class XverseForCausalLM(nn.Module): ...@@ -299,10 +303,13 @@ class XverseForCausalLM(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config=None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config) self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
......
...@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager ...@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
...@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model_memory_usage / float(2**30)) self.model_memory_usage / float(2**30))
if self.lora_config: if self.lora_config:
assert hasattr(self.model, "supported_lora_modules" assert supports_lora(self.model), "Model does not support LoRA"
) and self.model.supported_lora_modules, (
"Model does not support LoRA")
assert hasattr(
self.model,
"embedding_modules"), "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_batched_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment