Commit b22b7984 authored by Cyrus Leung, committed by GitHub

[Model] PP support for embedding models and update docs (#9090)


Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
parent f22619fe
...@@ -7,10 +7,12 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo ...@@ -7,10 +7,12 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.
---- Text-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^
Text Generation
---------------
Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:widths: 25 25 50 5 5
:header-rows: 1
...@@ -40,6 +42,11 @@ Decoder-only Language Models ...@@ -40,6 +42,11 @@ Decoder-only Language Models
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
-
- ✅︎
* - :code:`BartForConditionalGeneration`
- BART
- :code:`facebook/bart-base`, :code:`facebook/bart-large-cnn`, etc.
-
-
* - :code:`ChatGLMModel`
- ChatGLM
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
...@@ -259,11 +266,55 @@ Decoder-only Language Models ...@@ -259,11 +266,55 @@ Decoder-only Language Models
.. note::
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
.. _supported_vlms: Text Embedding
--------------
.. list-table::
:widths: 25 25 50 5 5
:header-rows: 1
* - Architecture
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Gemma2Model`
- Gemma2-based
- :code:`BAAI/bge-multilingual-gemma2`, etc.
-
- ✅︎
* - :code:`MistralModel`
- Mistral-based
- :code:`intfloat/e5-mistral-7b-instruct`, etc.
-
- ✅︎
Reward Modeling
---------------
.. list-table::
:widths: 25 25 50 5 5
:header-rows: 1
* - Architecture
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Qwen2ForRewardModel`
- Qwen2-based
- :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
-
- ✅︎
.. note::
As an interim measure, these models are supported via the Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. _supported_vlms:
.. list-table::
:widths: 25 25 25 25 5 5
:header-rows: 1
...@@ -378,6 +429,7 @@ Multimodal Language Models ...@@ -378,6 +429,7 @@ Multimodal Language Models
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
......
...@@ -6,10 +6,9 @@ Using VLMs ...@@ -6,10 +6,9 @@ Using VLMs
vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here <supported_vlms>`.
This document shows you how to run and serve these models using vLLM.
.. important:: .. note::
We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation. We are actively iterating on VLM support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
and `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
Offline Inference
-----------------
......
...@@ -7,7 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node ...@@ -7,7 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
""" """
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Optional from typing import List, Literal, NamedTuple, Optional
import pytest import pytest
...@@ -97,6 +97,9 @@ class PPTestSettings: ...@@ -97,6 +97,9 @@ class PPTestSettings:
self.trust_remote_code, self.tokenizer_mode) self.trust_remote_code, self.tokenizer_mode)
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model
# yapf: disable # yapf: disable
GENERATION_MODEL_SETTINGS = { GENERATION_MODEL_SETTINGS = {
# [DETAILED TESTS] # [DETAILED TESTS]
...@@ -104,15 +107,13 @@ GENERATION_MODEL_SETTINGS = { ...@@ -104,15 +107,13 @@ GENERATION_MODEL_SETTINGS = {
# [FAST TESTS] # [FAST TESTS]
# Uses Llama # Uses Llama
# "BAAI/AquilaChat-7B": PPTestSettings.fast(), # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
# TODO: Test on larger GPU "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
# "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True), "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"bigscience/bloomz-1b1": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(),
"THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True), "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501 "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501
# TODO: Test on larger GPU "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
# "databricks/dbrx-instruct": PPTestSettings.fast(),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True), "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
"deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
...@@ -161,8 +162,9 @@ GENERATION_MODEL_SETTINGS = { ...@@ -161,8 +162,9 @@ GENERATION_MODEL_SETTINGS = {
EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated] EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated]
# [FAST TESTS] # [FAST TESTS]
# Uses Llama "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
# "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
} }
MULTIMODAL_MODEL_SETTINGS = { MULTIMODAL_MODEL_SETTINGS = {
...@@ -192,40 +194,35 @@ CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated] ...@@ -192,40 +194,35 @@ CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated]
} }
# yapf: enable # yapf: enable
MODEL_SETTINGS = { # NOTE: You can update this on your local machine to run specific tests
**GENERATION_MODEL_SETTINGS,
**EMBEDDING_MODEL_SETTINGS,
**MULTIMODAL_MODEL_SETTINGS,
}
# You can update this on your local machine to run specific tests
TEST_MODELS = [ TEST_MODELS = [
# [LANGUAGE GENERATION]
"meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B",
"facebook/chameleon-7b", "ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct",
"BAAI/bge-multilingual-gemma2",
# [MULTIMODAL GENERATION]
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct", "microsoft/Phi-3-vision-128k-instruct",
"mistralai/Pixtral-12B-2409",
"fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_3",
] ]
@pytest.mark.parametrize( def _compare_tp(
("model_name", "parallel_setup", "distributed_backend", model_name: str,
"trust_remote_code", "tokenizer_mode"), parallel_setup: ParallelSetup,
[ distributed_backend: str,
params for model_name, settings in MODEL_SETTINGS.items() trust_remote_code: bool,
for params in settings.iter_params(model_name) tokenizer_mode: Optional[str],
if model_name in TEST_MODELS num_gpus_available: int,
], *,
) method: Literal["generate", "encode"] = "encode",
@fork_new_process_for_each_test ):
def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
distributed_backend: str, trust_remote_code: bool,
tokenizer_mode: Optional[str], num_gpus_available):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
if num_gpus_available < tp_size: if num_gpus_available < tp_size * pp_size:
pytest.skip(f"Need at least {tp_size} GPUs to run the test") pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
if VLLM_MULTI_NODE and distributed_backend == "mp": if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend") "multiprocessing distributed backend")
...@@ -286,10 +283,95 @@ def test_compare_tp(model_name: str, parallel_setup: ParallelSetup, ...@@ -286,10 +283,95 @@ def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
] ]
try: try:
compare_two_settings(model_name, pp_args, tp_args, pp_env) compare_two_settings(model_name,
pp_args,
tp_args,
pp_env,
method=method)
except Exception: except Exception:
if pp_env is None: if pp_env is None:
raise raise
else: else:
# Ray ADAG tests are flaky, so we don't want to fail the test # Ray ADAG tests are flaky, so we don't want to fail the test
logger.exception("Ray ADAG tests failed") logger.exception("Ray ADAG tests failed")
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="generate")
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="encode")
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
):
_compare_tp(model_name,
parallel_setup,
distributed_backend,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
method="generate")
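The three test functions above share one parametrization pattern: each walks its own settings dictionary, expands every entry through `iter_params`, and keeps only the models named in `TEST_MODELS`. The sketch below illustrates that filtering in isolation. `FakeSettings`, its `iter_params`, and `gpus_needed` are hypothetical stand-ins; the real `PPTestSettings` expansion logic is not shown in this diff and likely yields more fields.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, Tuple


@dataclass
class FakeSettings:
    """Hypothetical stand-in for PPTestSettings (not the real class)."""
    parallel_setups: List[Tuple[int, int]]  # (tp_size, pp_size) pairs

    def iter_params(self, model_name: str) -> Iterator[Tuple[str, int, int]]:
        # Yield one parameter tuple per parallel setup.
        for tp_size, pp_size in self.parallel_setups:
            yield (model_name, tp_size, pp_size)


# Illustrative settings; the sizes are made up for this sketch.
EMBEDDING_SETTINGS: Dict[str, FakeSettings] = {
    "intfloat/e5-mistral-7b-instruct": FakeSettings([(1, 2), (2, 2)]),
    "Qwen/Qwen2.5-Math-RM-72B": FakeSettings([(4, 2)]),
}

# Only models listed here are expanded into test cases.
TEST_MODELS = ["intfloat/e5-mistral-7b-instruct"]

params = [
    p for model_name, settings in EMBEDDING_SETTINGS.items()
    for p in settings.iter_params(model_name)
    if model_name in TEST_MODELS
]


def gpus_needed(tp_size: int, pp_size: int) -> int:
    # Mirrors the updated skip check: a run needs tp_size * pp_size GPUs.
    return tp_size * pp_size
```

Splitting the settings by task lets each test pass the matching `method` ("generate" or "encode") without inspecting the model name at run time.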
...@@ -8,13 +8,13 @@ import time ...@@ -8,13 +8,13 @@ import time
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Literal, Optional, Union
import openai import openai
import pytest import pytest
import requests import requests
from openai.types.completion import Completion from openai.types.completion import Completion
from typing_extensions import ParamSpec from typing_extensions import ParamSpec, assert_never
from tests.models.utils import TextTextLogprobs from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
...@@ -163,11 +163,140 @@ class RemoteOpenAIServer: ...@@ -163,11 +163,140 @@ class RemoteOpenAIServer:
) )
def _test_completion(
client: openai.OpenAI,
model: str,
prompt: str,
token_ids: List[int],
):
results = []
# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
temperature=0.0)
results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test using token IDs
completion = client.completions.create(
model=model,
prompt=token_ids,
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
# test simple list
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})
# test streaming
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})
return results
def _test_embeddings(
client: openai.OpenAI,
model: str,
text: str,
):
results = []
# test with text input
embeddings = client.embeddings.create(
model=model,
input=text,
encoding_format="float",
)
results.append({
"test": "single_embedding",
"embedding": embeddings.data[0].embedding,
"usage": embeddings.usage,
})
return results
def compare_two_settings(model: str, def compare_two_settings(model: str,
arg1: List[str], arg1: List[str],
arg2: List[str], arg2: List[str],
env1: Optional[Dict[str, str]] = None, env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None, env2: Optional[Dict[str, str]] = None,
*,
method: Literal["generate", "encode"] = "generate",
max_wait_seconds: Optional[float] = None) -> None: max_wait_seconds: Optional[float] = None) -> None:
""" """
Launch API server with two different sets of arguments/environments Launch API server with two different sets of arguments/environments
...@@ -219,96 +348,12 @@ def compare_two_settings(model: str, ...@@ -219,96 +348,12 @@ def compare_two_settings(model: str,
"root": served_model.root, "root": served_model.root,
}) })
# test with text prompt if method == "generate":
completion = client.completions.create(model=model, results += _test_completion(client, model, prompt, token_ids)
prompt=prompt, elif method == "encode":
max_tokens=5, results += _test_embeddings(client, model, prompt)
temperature=0.0) else:
assert_never(method)
results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test using token IDs
completion = client.completions.create(
model=model,
prompt=token_ids,
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
# test simple list
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})
# test streaming
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})
n = len(results) // 2 n = len(results) // 2
arg1_results = results[:n] arg1_results = results[:n]
......
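The `method` switch added to `compare_two_settings` pairs a `Literal` type with `assert_never`, so a type checker can flag an unhandled method name and the runtime fails loudly on one. The following is a minimal sketch of that dispatch pattern only. `run_checks` and its return values are hypothetical placeholders for `_test_completion` and `_test_embeddings`, and the `assert_never` fallback is an assumption for interpreters older than Python 3.11.

```python
from typing import Dict, List, Literal

try:
    from typing import assert_never  # available in Python 3.11+
except ImportError:  # fallback for older interpreters (assumption)
    from typing import NoReturn

    def assert_never(value: NoReturn) -> NoReturn:
        raise AssertionError(f"Unhandled value: {value!r}")

Method = Literal["generate", "encode"]


def run_checks(method: Method, prompt: str) -> List[Dict[str, str]]:
    """Hypothetical dispatcher; returns placeholder records, not API results."""
    if method == "generate":
        return [{"test": "single_completion", "prompt": prompt}]
    elif method == "encode":
        return [{"test": "single_embedding", "input": prompt}]
    else:
        # Reached only if a caller bypasses the Literal type.
        assert_never(method)
```

A static type checker treats the final branch as unreachable while both literals are handled, so adding a third method without a matching branch is reported before the tests run.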
...@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -273,7 +273,7 @@ class Gemma2Model(nn.Module): ...@@ -273,7 +273,7 @@ class Gemma2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
...@@ -308,6 +308,49 @@ class Gemma2Model(nn.Module): ...@@ -308,6 +308,49 @@ class Gemma2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
...@@ -391,48 +434,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -391,48 +434,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ weights_group = group_weights_with_prefix(weights)
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), self.model.load_weights(weights_group["model"])
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), if not self.config.tie_word_embeddings:
("gate_up_proj", "gate_proj", 0), # NOTE: For now self.lm_head is not defined because
("gate_up_proj", "up_proj", 1), # tie_word_embeddings is assumed to be False
] lm_head_dict = dict(self.lm_head.named_parameters())
params_dict = dict(self.named_parameters()) for name, loaded_weight in weights_group["lm_head"]:
loaded_params: Set[str] = set() if is_pp_missing_parameter(name, self.lm_head):
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue continue
param = params_dict[name]
param = lm_head_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
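The new `Gemma2ForCausalLM.load_weights` relies on `group_weights_with_prefix`, whose body is not part of this diff. Judging from the calls above (`weights_group["model"]`, and looking up names directly in `lm_head_dict`), it appears to split checkpoint entries by their first dotted component and drop that component from each name. The sketch below shows that idea under this assumption. `group_by_prefix_sketch` is a hypothetical name and not the vLLM implementation, and plain integers stand in for tensors.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple


def group_by_prefix_sketch(
    weights: Iterable[Tuple[str, Any]],
) -> Dict[str, List[Tuple[str, Any]]]:
    """Group (name, value) pairs by the first dotted component of the name.

    Assumption: the prefix is stripped from each name, so a submodule can
    look the remainder up in its own named_parameters().
    """
    groups: Dict[str, List[Tuple[str, Any]]] = defaultdict(list)
    for name, value in weights:
        prefix, _, rest = name.partition(".")
        groups[prefix].append((rest, value))
    return dict(groups)


# Placeholder values instead of tensors, for illustration only.
checkpoint = [
    ("model.layers.0.mlp.up_proj.weight", 1),
    ("model.norm.weight", 2),
    ("lm_head.weight", 3),
]
grouped = group_by_prefix_sketch(checkpoint)
```

With the weights grouped this way, the causal LM wrapper can hand the `"model"` group to the inner model and treat `"lm_head"` separately, which is what lets the embedding model reuse the same inner `load_weights`.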
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .gemma2 import Gemma2Model
from .interfaces import SupportsPP
class Gemma2EmbeddingModel(nn.Module):
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
"""A model that uses Gemma2 with additional embedding functionalities. """A model that uses Gemma2 with additional embedding functionalities.
This class encapsulates the Gemma2Model and provides an interface for This class encapsulates the Gemma2Model and provides an interface for
...@@ -30,6 +31,9 @@ class Gemma2EmbeddingModel(nn.Module): ...@@ -30,6 +31,9 @@ class Gemma2EmbeddingModel(nn.Module):
self.model = Gemma2Model(**kwargs) self.model = Gemma2Model(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
...@@ -38,10 +42,9 @@ class Gemma2EmbeddingModel(nn.Module): ...@@ -38,10 +42,9 @@ class Gemma2EmbeddingModel(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
return self.model.forward(input_ids, positions, kv_caches, return self.model(input_ids, positions, kv_caches, attn_metadata,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
def pooler( def pooler(
self, self,
...@@ -51,32 +54,4 @@ class Gemma2EmbeddingModel(nn.Module): ...@@ -51,32 +54,4 @@ class Gemma2EmbeddingModel(nn.Module):
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ self.model.load_weights(weights)
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.model.named_parameters())
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
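The stacked-parameter loop that this commit moves out of `Gemma2EmbeddingModel` and into `Gemma2Model.load_weights` rewrites per-projection checkpoint names into the fused parameter names and records which shard each tensor fills. The sketch below isolates only that name resolution. `resolve_fused_name` is a hypothetical helper written for illustration; it performs no weight loading and skips the bias and pipeline-parallel checks.

```python
from typing import List, Optional, Tuple, Union

ShardId = Union[str, int]

# Same (param_name, shard_name, shard_id) triples as in the diff.
STACKED_PARAMS_MAPPING: List[Tuple[str, str, ShardId]] = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def resolve_fused_name(name: str) -> Tuple[str, Optional[ShardId]]:
    """Map a checkpoint name to (fused parameter name, shard id).

    Names that match no shard are returned unchanged with shard id None,
    which corresponds to the for/else branch in the original loop.
    """
    for param_name, shard_name, shard_id in STACKED_PARAMS_MAPPING:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None
```

For example, `resolve_fused_name("layers.0.self_attn.k_proj.weight")` resolves to the fused `qkv_proj` parameter with shard id `"k"`, while a name such as `norm.weight` passes through untouched.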
...@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors ...@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, group_weights_with_prefix,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -347,6 +348,90 @@ class LlamaModel(nn.Module): ...@@ -347,6 +348,90 @@ class LlamaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for layer_idx, scaling_factor in kv_cache_scales_loader(
quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers,
self.config.__class__.model_type):
if not isinstance(self.layers[layer_idx], nn.Identity):
layer_self_attn = self.layers[layer_idx].self_attn
if is_hip():
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
 class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
@@ -372,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "gate_proj": ("gate_up_proj", 0),
         "up_proj": ("gate_up_proj", 1),
     }
     # Mistral/Llama models can also be loaded with --load-format mistral
     # from consolidated.safetensors checkpoints
     mistral_mapping = {
@@ -465,103 +551,38 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
+        weights = [
+            self.maybe_remap_mistral(name, loaded_weight)
+            for name, loaded_weight in weights
         ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = loaded_weight[0]
-                weight_loader(param, loaded_weight)
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
+        weights_group = group_weights_with_prefix(weights)
+        self.model.load_weights(weights_group["model"])
+        if not self.config.tie_word_embeddings:
+            lm_head_dict = dict(self.lm_head.named_parameters())
+            for name, loaded_weight in weights_group["lm_head"]:
+                if is_pp_missing_parameter(name, self.lm_head):
                     continue
-                param = params_dict[name]
+                param = lm_head_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
-            if not isinstance(self.model.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.model.layers[layer_idx].self_attn
-            if is_hip():
-                # The scaling factor convention we are assuming is
-                # quantized_value * scaling_factor ~= true_value
-                # which is consistent with the practice of setting
-                # scaling_factor = tensor_amax / FPtype_max
-                scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
-            else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
+        self.model.load_kv_cache_scales(quantization_param_path)
     # This function is used to remap the mistral format as
     # used by Mistral and Llama <=2
     def maybe_remap_mistral(
-            self, name: str,
-            loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+    ) -> Tuple[str, torch.Tensor]:
-        def permute(w, n_heads):
+        def permute(w: torch.Tensor, n_heads: int):
             attn_in = self.config.head_dim * n_heads
             attn_out = self.config.hidden_size
......
@@ -5,13 +5,11 @@ from torch import nn
 from vllm.attention import AttentionMetadata
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from .interfaces import SupportsPP
-from .utils import is_pp_missing_parameter
+from .llama import LlamaModel
 class LlamaEmbeddingModel(nn.Module, SupportsPP):
@@ -44,9 +42,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model.forward(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors,
-                                  inputs_embeds)
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors, inputs_embeds)
     def pooler(
         self,
@@ -56,43 +53,7 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
         return self._pooler(hidden_states, pooling_metadata)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.model.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        self.model.load_weights(weights)
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
@@ -48,7 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, group_weights_with_prefix,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
@@ -300,6 +301,47 @@ class Qwen2Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
 class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
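The `load_weights` loop added to `Qwen2Model` relies on Python's `for ... else` to decide whether a checkpoint tensor belongs to a fused parameter. A minimal sketch of just that name-remapping step follows, assuming the same mapping triples as in the diff. The function name `remap_stacked_name` is hypothetical and is not part of vLLM; the real loop also handles bias skipping, kv-scale remapping, and pipeline-parallel checks, which are left out here.

```python
from typing import List, Optional, Tuple, Union

ShardId = Union[str, int]

# Same (param_name, shard_name, shard_id) triples as in the diff above.
STACKED_PARAMS_MAPPING: List[Tuple[str, str, ShardId]] = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap_stacked_name(name: str) -> Tuple[str, Optional[ShardId]]:
    """Hypothetical helper: map a checkpoint name to (fused name, shard id)."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name not in name:
            continue
        result = (name.replace(weight_name, param_name), shard_id)
        break
    else:
        # Reached only when no mapping entry matched (the loop never hit
        # `break`), so the weight is loaded under its own name.
        result = (name, None)
    return result


print(remap_stacked_name("layers.0.self_attn.q_proj.weight"))
print(remap_stacked_name("norm.weight"))
```

In the diff itself, the matched branch calls `weight_loader(param, loaded_weight, shard_id)`, while the `else` branch falls back to `default_weight_loader` when the parameter has no loader of its own.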
@@ -393,44 +435,17 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
+        weights_group = group_weights_with_prefix(weights)
+        self.model.load_weights(weights_group["model"])
+        if not self.config.tie_word_embeddings:
+            lm_head_dict = dict(self.lm_head.named_parameters())
+            for name, loaded_weight in weights_group["lm_head"]:
+                if is_pp_missing_parameter(name, self.lm_head):
                     continue
-                param = params_dict[name]
+                param = lm_head_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
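The body of `group_weights_with_prefix` is not shown in this part of the diff. Judging only from the call sites (`weights_group["model"]`, `weights_group["lm_head"]`, and the lookup of the resulting names in `self.lm_head.named_parameters()`), it appears to split weights by their first dotted component and strip that component from the name. The sketch below illustrates that assumed behaviour with a hypothetical stand-in, `group_by_prefix_sketch`; it materializes lists for simplicity, which the actual helper may not do.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def group_by_prefix_sketch(
        weights: Iterable[Tuple[str, T]]) -> Dict[str, List[Tuple[str, T]]]:
    """Hypothetical stand-in, not the vLLM implementation.

    Groups (name, tensor) pairs by the first dotted component of the name
    and removes that component, so "lm_head.weight" lands in the "lm_head"
    group under the name "weight".
    """
    groups: Dict[str, List[Tuple[str, T]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        groups[prefix].append((rest, tensor))
    return dict(groups)


checkpoint = [
    ("model.embed_tokens.weight", "t0"),
    ("model.layers.0.mlp.up_proj.weight", "t1"),
    ("lm_head.weight", "t2"),
]
grouped = group_by_prefix_sketch(checkpoint)
print(sorted(grouped))
print(grouped["lm_head"])
```

Under this assumption, each submodule only sees names relative to itself, which is what lets `LlamaForCausalLM`, `Qwen2ForCausalLM`, and the embedding and reward models delegate to the inner model's `load_weights` without duplicating the loop.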
@@ -4,7 +4,7 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, Union
 import torch
 from torch import nn
@@ -15,15 +15,14 @@ from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
-from .utils import is_pp_missing_parameter
+from .interfaces import SupportsPP
+from .qwen2 import Qwen2Model
+from .utils import group_weights_with_prefix
 class ReLU(nn.Module):
@@ -37,7 +36,7 @@ class ReLU(nn.Module):
         return self.activation(input)
-class Qwen2ForRewardModel(nn.Module):
+class Qwen2ForRewardModel(nn.Module, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -97,6 +96,9 @@ class Qwen2ForRewardModel(nn.Module):
         )
         self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -104,7 +106,7 @@ class Qwen2ForRewardModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, intermediate_tensors)
         logits, _ = self.score(hidden_states)
@@ -118,45 +120,13 @@ class Qwen2ForRewardModel(nn.Module):
         return self._pooler(hidden_states, pooling_metadata)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
-            # Skip loading lm_head for embedding model
-            if name == "lm_head.weight":
-                continue
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        weights_group = group_weights_with_prefix(weights)
+        self.model.load_weights(weights_group["model"])
+        score_dict = dict(self.score.named_parameters())
+        for name, loaded_weight in weights_group["score"]:
+            param = score_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
@@ -306,10 +306,12 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
 def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
     """Check if a parameter is missing in a pipeline parallel model."""
-    for missing_layer_name in get_pp_missing_layer_names(model):
-        if name.startswith(missing_layer_name):
-            return True
-    return False
+    if isinstance(model, PPMissingLayer):
+        return True
+    return any(
+        name.startswith(missing_layer_name)
+        for missing_layer_name in get_pp_missing_layer_names(model))
 def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
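The updated `is_pp_missing_parameter` adds an early return for the case where the module passed in is itself a placeholder, which matters now that callers pass a submodule such as `self.lm_head` rather than the whole model. The sketch below reproduces that control flow with plain Python stand-ins so it runs without torch. `MissingLayerStub`, `ModelStub`, and `is_missing_sketch` are hypothetical names, and the list of missing prefixes is supplied by hand instead of being collected by `get_pp_missing_layer_names`.

```python
from typing import List


class MissingLayerStub:
    """Hypothetical stand-in for the PPMissingLayer placeholder."""


class ModelStub:
    """Hypothetical model that records which layer prefixes are absent."""

    def __init__(self, missing_layer_names: List[str]) -> None:
        self.missing_layer_names = missing_layer_names


def is_missing_sketch(name: str, model: object) -> bool:
    # A placeholder module owns no parameters on this rank, so every
    # name looked up against it counts as missing.
    if isinstance(model, MissingLayerStub):
        return True
    # Otherwise the parameter is missing if it lives under any layer
    # that this rank does not hold.
    return any(
        name.startswith(prefix) for prefix in model.missing_layer_names)


rank_model = ModelStub(missing_layer_names=["layers.0.", "layers.1."])
print(is_missing_sketch("layers.0.mlp.up_proj.weight", rank_model))
print(is_missing_sketch("layers.2.mlp.up_proj.weight", rank_model))
print(is_missing_sketch("weight", MissingLayerStub()))
```

Replacing the explicit loop with `any(...)` keeps the original short-circuit behaviour: evaluation stops at the first matching prefix.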
......
 import dataclasses
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import torch
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -66,7 +67,7 @@ class EmbeddingModelRunner(
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
-    ) -> Optional[List[PoolerOutput]]:
+    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError(
                 "EmbeddingModelRunner does not support multi-step execution.")
@@ -107,28 +108,52 @@ class EmbeddingModelRunner(
             for _ in range(num_layers)
         ]
-        execute_model_kwargs = {
-            "input_ids":
-            model_input.input_tokens,
-            "positions":
-            model_input.input_positions,
-            "kv_caches":
-            kv_caches,
-            "attn_metadata":
-            model_input.attn_metadata,
-            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
-        }
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            model_forward_start = torch.cuda.Event(enable_timing=True)
+            model_forward_end = torch.cuda.Event(enable_timing=True)
+            model_forward_start.record()
         with set_forward_context(model_input.attn_metadata):
-            hidden_states = model_executable(**execute_model_kwargs)
+            hidden_or_intermediate_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                intermediate_tensors=intermediate_tensors,
+                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+                                             device=self.device))
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            model_forward_end.record()
+        # Only perform pooling in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            if (self.is_driver_worker
+                    and hidden_or_intermediate_states is not None
+                    and isinstance(hidden_or_intermediate_states,
+                                   IntermediateTensors)
+                    and self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                hidden_or_intermediate_states.tensors["model_forward_time"] = (
+                    torch.tensor(model_forward_time + orig_model_forward_time))
+            return hidden_or_intermediate_states
         # Only perform pooling in the driver worker.
         if not self.is_driver_worker:
             return []
         return [
-            self.model.pooler(hidden_states=hidden_states,
+            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                               pooling_metadata=model_input.pooling_metadata)
         ]
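On non-final pipeline stages, the runner above adds its own measured forward time to whatever `model_forward_time` arrived with the incoming intermediate tensors, then forwards the sum to the next stage. The sketch below isolates that accumulation using plain dictionaries of floats in place of `IntermediateTensors` and CUDA event timings. `accumulate_forward_time` is a hypothetical helper written for illustration and does not exist in vLLM.

```python
from typing import Dict, Optional


def accumulate_forward_time(
        stage_time_ms: float,
        incoming: Optional[Dict[str, float]]) -> Dict[str, float]:
    """Hypothetical helper mirroring the running-total logic in the diff."""
    previous = 0.0
    if incoming is not None:
        # The first stage receives nothing, so the default of 0.0 applies.
        previous = incoming.get("model_forward_time", 0.0)
    return {"model_forward_time": stage_time_ms + previous}


# Three pipeline stages with made-up timings in milliseconds.
outgoing = None
for stage_time in (1.5, 2.0, 0.5):
    outgoing = accumulate_forward_time(stage_time, outgoing)
print(outgoing)
```

In the diff, only stages that are not the last rank return early with the updated tensors; the last rank falls through to the pooling step instead.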
......