Commit 0f6d7a9a authored by Murali Andoorveedu, committed by GitHub

[Models] Add remaining model PP support (#7168)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
......@@ -146,7 +146,9 @@ steps:
source_file_dependencies:
- vllm/
- tests/test_regression
command: pytest -v -s test_regression.py
commands:
- pip install modelscope
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional
- label: Engine Test # 10min
......
......@@ -12,201 +12,249 @@ Alongside each architecture, we include some popular models that use it.
Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:widths: 25 25 50 5
:widths: 25 25 50 5 5
:header-rows: 1
* - Architecture
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`AquilaForCausalLM`
- Aquila, Aquila2
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
- ✅︎
- ✅︎
* - :code:`ArcticForCausalLM`
- Arctic
- :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc.
-
- ✅︎
* - :code:`BaiChuanForCausalLM`
- Baichuan2, Baichuan
- :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc.
- ✅︎
- ✅︎
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
-
- ✅︎
* - :code:`ChatGLMModel`
- ChatGLM
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
- ✅︎
- ✅︎
* - :code:`CohereForCausalLM`
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
-
- ✅︎
- ✅︎
* - :code:`DbrxForCausalLM`
- DBRX
- :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc.
-
- ✅︎
* - :code:`DeciLMForCausalLM`
- DeciLM
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
-
- ✅︎
* - :code:`DeepseekForCausalLM`
- DeepSeek
- :code:`deepseek-ai/deepseek-llm-67b-base`, :code:`deepseek-ai/deepseek-llm-7b-chat` etc.
-
- ✅︎
* - :code:`DeepseekV2ForCausalLM`
- DeepSeek-V2
- :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc.
-
- ✅︎
* - :code:`ExaoneForCausalLM`
- EXAONE-3
- :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
- ✅︎
- ✅︎
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
-
- ✅︎
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
- ✅︎
- ✅︎
* - :code:`Gemma2ForCausalLM`
- Gemma2
- :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc.
- ✅︎
- ✅︎
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
-
- ✅︎
* - :code:`GPTBigCodeForCausalLM`
- StarCoder, SantaCoder, WizardCoder
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
- ✅︎
- ✅︎
* - :code:`GPTJForCausalLM`
- GPT-J
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
-
- ✅︎
* - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
-
- ✅︎
* - :code:`GraniteForCausalLM`
- PowerLM
- :code:`ibm/PowerLM-3b` etc.
- ✅︎
- ✅︎
* - :code:`GraniteMoeForCausalLM`
- PowerMoE
- :code:`ibm/PowerMoE-3b` etc.
- ✅︎
- ✅︎
* - :code:`InternLMForCausalLM`
- InternLM
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
- ✅︎
- ✅︎
* - :code:`InternLM2ForCausalLM`
- InternLM2
- :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
-
- ✅︎
* - :code:`JAISLMHeadModel`
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
-
- ✅︎
* - :code:`JambaForCausalLM`
- Jamba
- :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
-
* - :code:`LlamaForCausalLM`
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
-
- ✅︎
- ✅︎
* - :code:`MiniCPM3ForCausalLM`
- MiniCPM3
- :code:`openbmb/MiniCPM3-4B`, etc.
-
- ✅︎
- ✅︎
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
- ✅︎
- ✅︎
* - :code:`MixtralForCausalLM`
- Mixtral-8x7B, Mixtral-8x7B-Instruct
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
- ✅︎
- ✅︎
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
-
- ✅︎
* - :code:`NemotronForCausalLM`
- Nemotron-3, Nemotron-4, Minitron
- :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc.
- ✅︎
* - :code:`OLMoEForCausalLM`
- OLMoE
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
-
- ✅︎
* - :code:`OLMoForCausalLM`
- OLMo
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
-
- ✅︎
* - :code:`OLMoEForCausalLM`
- OLMoE
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
- ✅︎
- ✅︎
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
-
- ✅︎
* - :code:`OrionForCausalLM`
- Orion
- :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc.
-
- ✅︎
* - :code:`PhiForCausalLM`
- Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
- ✅︎
- ✅︎
* - :code:`Phi3ForCausalLM`
- Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, :code:`microsoft/Phi-3-medium-128k-instruct`, etc.
-
- ✅︎
- ✅︎
* - :code:`Phi3SmallForCausalLM`
- Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
-
- ✅︎
* - :code:`PhiMoEForCausalLM`
- Phi-3.5-MoE
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
-
- ✅︎
- ✅︎
* - :code:`PersimmonForCausalLM`
- Persimmon
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
-
- ✅︎
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
-
- ✅︎
* - :code:`Qwen2ForCausalLM`
- Qwen2
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
- ✅︎
- ✅︎
* - :code:`Qwen2MoeForCausalLM`
- Qwen2MoE
- :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
-
- ✅︎
* - :code:`StableLmForCausalLM`
- StableLM
- :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
-
- ✅︎
* - :code:`Starcoder2ForCausalLM`
- Starcoder2
- :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc.
-
- ✅︎
* - :code:`SolarForCausalLM`
- EXAONE-3
- Solar Pro
- :code:`upstage/solar-pro-preview-instruct`, etc.
-
- ✅︎
- ✅︎
* - :code:`XverseForCausalLM`
- Xverse
- XVERSE
- :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc.
-
- ✅︎
- ✅︎
.. note::
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
......@@ -217,7 +265,7 @@ Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
:widths: 25 25 25 25 5
:widths: 25 25 25 25 5 5
:header-rows: 1
* - Architecture
......@@ -225,86 +273,103 @@ Multimodal Language Models
- Modalities
- Example HuggingFace Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- Image\ :sup:`E`
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
- ✅︎
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- Image
- :code:`facebook/chameleon-7b` etc.
-
- ✅︎
* - :code:`FuyuForCausalLM`
- Fuyu
- Image
- :code:`adept/fuyu-8b` etc.
-
- ✅︎
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E+`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- Image\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- Image\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
- ✅︎
* - :code:`LlavaNextVideoForConditionalGeneration`
- LLaVA-NeXT-Video
- Video
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc.
-
- ✅︎
* - :code:`LlavaOnevisionForConditionalGeneration`
- LLaVA-Onevision
- Image\ :sup:`+` / Video
- :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
-
- ✅︎
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
- ✅︎
- ✅︎
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision, Phi-3.5-Vision
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
- ✅︎
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
- ✅︎
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
- ✅︎
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Image\ :sup:`E+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
- ✅︎
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
- :code:`fixie-ai/ultravox-v0_3`
-
- ✅︎
| :sup:`E` Pre-computed embeddings can be inputted for this modality.
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
......
......@@ -10,8 +10,8 @@ pytest-shard
awscli
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio test
opencv-python # required for video test
librosa # required for audio tests
opencv-python # required for video tests
peft
requests
ray[adag]==2.35
......
......@@ -6,6 +6,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
to fail.
"""
import os
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
import pytest
......@@ -18,49 +20,256 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
eager_mode: bool
chunked_prefill: bool
@dataclass
class PPTestSettings:
parallel_setups: List[ParallelSetup]
distributed_backends: List[str]
trust_remote_code: bool
tokenizer_mode: Optional[str]
@staticmethod
def detailed(
*,
tp_base: int = 1,
pp_base: int = 2,
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=False,
chunked_prefill=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=True,
chunked_prefill=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=False,
chunked_prefill=True),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
@staticmethod
def fast(
*,
tp_base: int = 1,
pp_base: int = 2,
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
):
return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=True,
chunked_prefill=False),
],
distributed_backends=["mp"],
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
def iter_params(self, model_name: str):
for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
self.trust_remote_code, self.tokenizer_mode)
# yapf: disable
GENERATION_MODEL_SETTINGS = {
# [DETAILED TESTS]
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
# [FAST TESTS]
# Uses Llama
# "BAAI/AquilaChat-7B": PPTestSettings.fast(),
# TODO: Test on larger GPU
# "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"bigscience/bloomz-1b1": PPTestSettings.fast(),
"THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501
# TODO: Test on larger GPU
# "databricks/dbrx-instruct": PPTestSettings.fast(),
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
"deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
"tiiuae/falcon-7b": PPTestSettings.fast(),
"google/gemma-2b": PPTestSettings.fast(),
"google/gemma-2-9b": PPTestSettings.fast(),
"gpt2": PPTestSettings.fast(),
"bigcode/starcoder": PPTestSettings.fast(),
"EleutherAI/gpt-j-6b": PPTestSettings.fast(),
"EleutherAI/pythia-12b": PPTestSettings.fast(),
"ibm/PowerLM-3b": PPTestSettings.fast(),
"ibm/PowerMoE-3b": PPTestSettings.fast(),
# Uses Llama
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"core42/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
# Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
"allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
"allenai/OLMo-1B-hf": PPTestSettings.fast(),
"facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
"OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
"microsoft/phi-2": PPTestSettings.fast(),
"microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
"microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
# FIXME: https://github.com/vllm-project/vllm/issues/8553
# "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"adept/persimmon-8b-chat": PPTestSettings.fast(),
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
"bigcode/starcoder2-3b": PPTestSettings.fast(),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
# FIXME: Cannot load tokenizer in latest transformers version
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
}
EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated]
# [FAST TESTS]
# Uses Llama
# "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
}
MULTIMODAL_MODEL_SETTINGS = {
# [FAST TESTS]
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(),
"OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
"llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
"llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
"llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
# TODO: Implement PP
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(),
}
CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated]
# [FAST TESTS]
# TODO: Implement PP
# "facebook/bart-base": PPTestSettings.fast(),
}
# yapf: enable
MODEL_SETTINGS = {
**GENERATION_MODEL_SETTINGS,
**EMBEDDING_MODEL_SETTINGS,
**MULTIMODAL_MODEL_SETTINGS,
}
# You can update this on your local machine to run specific tests
TEST_MODELS = [
"meta-llama/Meta-Llama-3-8B",
"facebook/chameleon-7b",
"OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct",
"mistralai/Pixtral-12B-2409",
"fixie-ai/ultravox-v0_3",
]
@pytest.mark.parametrize(
("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
"MODEL_NAME, DIST_BACKEND"),
("model_name", "parallel_setup", "distributed_backend",
"trust_remote_code", "tokenizer_mode"),
[
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
# NOTE: InternVL2 multi-node tests are flaky,
# use mp backend to skip the multi-node tests
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
# TP only models
(2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
params for model_name, settings in MODEL_SETTINGS.items()
for params in settings.iter_params(model_name)
if model_name in TEST_MODELS
],
)
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
distributed_backend: str, trust_remote_code: bool,
tokenizer_mode: Optional[str], num_gpus_available):
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
if num_gpus_available < tp_size:
pytest.skip(f"Need at least {tp_size} GPUs to run the test")
if VLLM_MULTI_NODE and distributed_backend == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
pp_args = [
common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"8192",
"2048",
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
common_args.extend(["--tokenizer-mode", tokenizer_mode])
if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
and chunked_prefill):
# Test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of aDAG issue.
common_args.append("--disable-frontend-multiprocessing")
else:
pp_env = None
pp_args = [
*common_args,
"--pipeline-parallel-size",
str(PP_SIZE),
str(pp_size),
"--tensor-parallel-size",
str(TP_SIZE),
str(tp_size),
"--distributed-executor-backend",
DIST_BACKEND,
distributed_backend,
]
# compare without pipeline parallelism
......@@ -69,41 +278,15 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
# schedule all workers in a node other than the head node,
# which can cause the test to fail.
tp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"8192",
*common_args,
"--tensor-parallel-size",
str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI.
str(tp_size),
"--distributed-executor-backend",
"mp",
]
if CHUNKED_PREFILL:
pp_args.append("--enable-chunked-prefill")
tp_args.append("--enable-chunked-prefill")
if EAGER_MODE:
pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager")
if TRUST_REMOTE_CODE:
pp_args.append("--trust-remote-code")
tp_args.append("--trust-remote-code")
pp_env = None
if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
and CHUNKED_PREFILL):
# Test Ray ADAG for a subset of the tests
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of aDAG issue.
pp_args.append("--disable-frontend-multiprocessing")
tp_args.append("--disable-frontend-multiprocessing")
try:
compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
compare_two_settings(model_name, pp_args, tp_args, pp_env)
except Exception:
if pp_env is None:
raise
......
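The parametrization in the test above expands each model's settings into one test case per (parallel setup, backend) pair and then filters by `TEST_MODELS`. The following is a minimal, self-contained sketch of that expansion, not the test file itself: the model names `model-a` and `model-b`, the helper `build_params`, and the reduced `fast`/`detailed` settings are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple


class ParallelSetup(NamedTuple):
    tp_size: int
    pp_size: int
    eager_mode: bool
    chunked_prefill: bool


Params = Tuple[str, ParallelSetup, str, bool, Optional[str]]


@dataclass
class PPTestSettings:
    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    trust_remote_code: bool = False
    tokenizer_mode: Optional[str] = None

    def iter_params(self, model_name: str) -> Iterator[Params]:
        # One test case per (parallel setup, backend) combination.
        for setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (model_name, setup, backend,
                       self.trust_remote_code, self.tokenizer_mode)


# Hypothetical stand-ins for PPTestSettings.fast() and .detailed().
fast = PPTestSettings(
    parallel_setups=[ParallelSetup(1, 2, True, False)],
    distributed_backends=["mp"],
)
detailed = PPTestSettings(
    parallel_setups=[
        ParallelSetup(1, 2, False, False),
        ParallelSetup(1, 4, False, True),
        ParallelSetup(1, 4, True, False),
        ParallelSetup(2, 2, False, True),
        ParallelSetup(2, 2, True, False),
    ],
    distributed_backends=["mp", "ray"],
)


def build_params(settings: Dict[str, PPTestSettings],
                 test_models: List[str]) -> List[Params]:
    # Same shape as the list comprehension passed to pytest.mark.parametrize.
    return [
        params for name, cfg in settings.items()
        for params in cfg.iter_params(name) if name in test_models
    ]


# "model-a" and "model-b" are placeholder names, not real checkpoints.
all_params = build_params({"model-a": detailed, "model-b": fast},
                          test_models=["model-a", "model-b"])
only_fast = build_params({"model-a": detailed, "model-b": fast},
                         test_models=["model-b"])
```

Under these assumptions the detailed settings contribute five setups times two backends, while each fast entry contributes a single case, which is why most models can be listed without inflating CI time.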
import warnings
import pytest
import torch.cuda
from vllm.model_executor.models import _MODELS, ModelRegistry
from vllm.platforms import current_platform
from ..utils import fork_new_process_for_each_test
@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
@pytest.mark.parametrize("model_arch", _MODELS)
def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls([model_cls])
ModelRegistry.resolve_model_cls(model_arch)
@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
("LlamaForCausalLM", False, False),
("MllamaForConditionalGeneration", True, False),
("LlavaForConditionalGeneration", True, True),
])
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()
ModelRegistry.resolve_model_cls(model_arch)
if not torch.cuda.is_initialized():
warnings.warn(
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)
@fork_new_process_for_each_test
@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [
("MLPSpeculatorPreTrainedModel", False, False),
("DeepseekV2ForCausalLM", True, False),
("Qwen2VLForConditionalGeneration", True, True),
])
def test_registry_is_pp(model_arch, is_pp, init_cuda):
assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp
if init_cuda and current_platform.is_cuda_alike():
assert not torch.cuda.is_initialized()
ModelRegistry.resolve_model_cls(model_arch)
if not torch.cuda.is_initialized():
warnings.warn(
"This model no longer initializes CUDA on import. "
"Please test using a different one.",
stacklevel=2)
......@@ -14,7 +14,6 @@ import openai
import pytest
import requests
from openai.types.completion import Completion
from transformers import AutoTokenizer
from typing_extensions import ParamSpec
from tests.models.utils import TextTextLogprobs
......@@ -24,6 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip)
......@@ -181,15 +181,26 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
"""
trust_remote_code = "--trust-remote-code"
if trust_remote_code in arg1 or trust_remote_code in arg2:
tokenizer = AutoTokenizer.from_pretrained(model,
trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(model)
trust_remote_code = False
for args in (arg1, arg2):
if "--trust-remote-code" in args:
trust_remote_code = True
break
tokenizer_mode = "auto"
for args in (arg1, arg2):
if "--tokenizer-mode" in args:
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
break
tokenizer = get_tokenizer(
model,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
prompt = "Hello, my name is"
token_ids = tokenizer(prompt)["input_ids"]
token_ids = tokenizer(prompt).input_ids
results = []
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
......
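The change to `compare_two_settings` above derives the tokenizer options by scanning both argument lists for `--trust-remote-code` and `--tokenizer-mode`. A minimal sketch of that lookup follows; the helper names `has_flag` and `find_flag_value` are hypothetical, and the bounds check on the flag's value is an addition that the diff does not include.

```python
from typing import Optional, Sequence


def has_flag(arg_lists: Sequence[Sequence[str]], flag: str) -> bool:
    # True if any of the argument lists contains the boolean flag.
    return any(flag in args for args in arg_lists)


def find_flag_value(arg_lists: Sequence[Sequence[str]],
                    flag: str,
                    default: Optional[str] = None) -> Optional[str]:
    # Return the token following the first occurrence of `flag`.
    for args in arg_lists:
        if flag in args:
            idx = list(args).index(flag)
            # Assumption: a trailing flag with no value falls back to default.
            if idx + 1 < len(args):
                return args[idx + 1]
    return default


arg1 = ["--dtype", "float16", "--tokenizer-mode", "mistral"]
arg2 = ["--dtype", "float16", "--trust-remote-code"]

trust_remote_code = has_flag((arg1, arg2), "--trust-remote-code")
tokenizer_mode = find_flag_value((arg1, arg2), "--tokenizer-mode", "auto")
```

Scanning both lists means the tokenizer is built with the more permissive of the two configurations, so the prompt is tokenized once and reused for both servers.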
......@@ -31,28 +31,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096
_PP_SUPPORTED_MODELS = [
"AquilaForCausalLM",
"AquilaModel",
"DeepseekV2ForCausalLM",
"GPT2LMHeadModel",
"InternLM2ForCausalLM",
"InternLMForCausalLM",
"InternVLChatModel",
"JAISLMHeadModel",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"MixtralForCausalLM",
"NemotronForCausalLM",
"Phi3ForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"Qwen2VLForConditionalGeneration",
]
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
class ModelConfig:
......@@ -228,15 +207,13 @@ class ModelConfig:
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
architectures = getattr(self.hf_config, "architectures", [])
if any(
ModelRegistry.is_multimodal_model(arch)
for arch in architectures):
if ModelRegistry.is_multimodal_model(architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
else:
if limit_mm_per_prompt:
raise ValueError(
"limit_mm_per_prompt is only supported for multimodal "
"models.")
raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.")
return None
def _verify_tokenizer_mode(self) -> None:
......@@ -249,8 +226,7 @@ class ModelConfig:
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
self.embedding_mode = any(
ModelRegistry.is_embedding_model(arch) for arch in architectures)
self.embedding_mode = ModelRegistry.is_embedding_model(architectures)
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
......@@ -417,14 +393,14 @@ class ModelConfig:
f"({tensor_parallel_size}).")
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if pipeline_parallel_size > 1:
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
if not ModelRegistry.is_pp_supported_model(architectures):
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
"Pipeline parallelism is not supported for this model. "
"Supported models implement the `SupportsPP` interface.")
if pipeline_parallel_size > 1 and self.use_async_output_proc:
if self.use_async_output_proc:
logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.")
self.use_async_output_proc = False
......
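The config change above replaces the hard-coded `_PP_SUPPORTED_MODELS` list with a capability check on the model class. The sketch below illustrates the general idea with a class-level marker. It is not vLLM's actual `SupportsPP` definition: every name here (`SupportsPPSketch`, `supports_pp_sketch`, `REGISTRY`, `verify_pp`) is hypothetical.

```python
from typing import Dict, List, Union


class SupportsPPSketch:
    # Hypothetical marker: classes opting in to pipeline parallelism set this.
    supports_pp = True


class WithPP(SupportsPPSketch):
    pass


class WithoutPP:
    pass


def supports_pp_sketch(model_cls: type) -> bool:
    # Capability check on the class itself, not on a list of names.
    return getattr(model_cls, "supports_pp", False) is True


# Hypothetical registry mapping architecture names to classes.
REGISTRY: Dict[str, type] = {"WithPP": WithPP, "WithoutPP": WithoutPP}


def is_pp_supported(architectures: Union[str, List[str]]) -> bool:
    # Accept a single name or a list, as the updated registry methods do.
    if isinstance(architectures, str):
        architectures = [architectures]
    return any(arch in REGISTRY and supports_pp_sketch(REGISTRY[arch])
               for arch in architectures)


def verify_pp(architectures: Union[str, List[str]],
              pipeline_parallel_size: int) -> None:
    # Only check support when pipeline parallelism is actually requested.
    if pipeline_parallel_size > 1 and not is_pp_supported(architectures):
        raise NotImplementedError(
            "Pipeline parallelism is not supported for this model.")
```

With this shape, adding PP support to a model only requires changing the model class; no central list has to be kept in sync.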
import functools
import importlib
from typing import Dict, List, Optional, Tuple, Type
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from .interfaces import supports_multimodal, supports_pp
logger = init_logger(__name__)
_GENERATION_MODELS = {
......@@ -152,19 +158,25 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
class ModelRegistry:
@staticmethod
@functools.lru_cache(maxsize=128)
def _get_model(model_arch: str):
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
return getattr(module, cls_name, None)
@staticmethod
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
......@@ -175,11 +187,24 @@ class ModelRegistry:
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
return ModelRegistry._get_model(model_arch)
return None
@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return model
return ModelRegistry._try_get_model_stateful(model_arch)
@staticmethod
def resolve_model_cls(
architectures: List[str]) -> Tuple[Type[nn.Module], str]:
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
......@@ -200,21 +225,99 @@ class ModelRegistry:
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
@staticmethod
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@lru_cache(maxsize=128)
def _check_stateless(
func: Callable[[Type[nn.Module]], bool],
model_arch: str,
*,
default: Optional[bool] = None,
) -> bool:
"""
Run a boolean function against a model and return the result.
If the model is not found, returns the provided default value.
If the model is not already imported, the function is run inside a
subprocess to avoid initializing CUDA for the main program.
"""
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return func(model)
if model_arch not in _MODELS and default is not None:
return default
module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")
err_id = uuid.uuid4()
stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])
result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)
if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")
return result.returncode == 0
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
return any(arch in _EMBEDDING_MODELS for arch in architectures)
@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_mm = partial(ModelRegistry._check_stateless,
supports_multimodal,
default=False)
return any(is_mm(arch) for arch in architectures)
@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
is_pp = partial(ModelRegistry._check_stateless,
supports_pp,
default=False)
# TODO: find a way to avoid initializing CUDA prematurely to
# use `supports_multimodal` to determine if a model is multimodal
# model_cls = ModelRegistry._try_load_model_cls(model_arch)
# from vllm.model_executor.models.interfaces import supports_multimodal
return model_arch in _MULTIMODAL_MODELS
return any(is_pp(arch) for arch in architectures)
__all__ = [
......
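Note on `_check_stateless` above: the check is pushed into a child interpreter, so a model module that is not yet imported never gets loaded (and never initializes CUDA) in the main process. A minimal standalone sketch of that pattern follows. The helper name `check_in_subprocess` and its four string arguments are assumptions made for illustration and are not part of the commit.

```python
import string
import subprocess
import sys
import uuid

# Same whitelist idea as in the diff: letters, digits, dot and underscore only.
_VALID_CHARS = set(string.ascii_letters + string.digits + "._")


def _check_name(name: str) -> None:
    # Reject anything that could smuggle extra statements into `python -c`.
    if not name or any(ch not in _VALID_CHARS for ch in name):
        raise ValueError(f"Unsafe name detected: {name!r}")


def check_in_subprocess(module_name: str, cls_name: str,
                        func_module: str, func_name: str) -> bool:
    """Run `func_name(cls_name)` in a child interpreter and return the result.

    Hypothetical helper: it takes plain strings instead of a callable so the
    sketch stays self-contained.
    """
    for name in (module_name, cls_name, func_module, func_name):
        _check_name(name)

    # A unique marker lets us tell "predicate returned False" apart from
    # "the import itself crashed".
    err_id = uuid.uuid4()
    stmts = ";".join([
        f"from {module_name} import {cls_name}",
        f"from {func_module} import {func_name}",
        f"assert {func_name}({cls_name}), '{err_id}'",
    ])
    result = subprocess.run([sys.executable, "-c", stmts],
                            capture_output=True)

    if result.returncode != 0:
        err_lines = result.stderr.decode().splitlines()
        if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
            raise RuntimeError("Unexpected error in child process:\n" +
                               "\n".join(err_lines))
    return result.returncode == 0
```

For instance, `check_in_subprocess("collections", "OrderedDict", "inspect", "isclass")` evaluates `isclass(OrderedDict)` in the child process. The caller only sees the boolean outcome.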
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
......@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
......@@ -32,6 +31,10 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
......@@ -364,6 +367,7 @@ class ArcticModel(nn.Module):
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -372,15 +376,16 @@ class ArcticModel(nn.Module):
self.vocab_size,
config.hidden_size,
org_num_embeddings=self.vocab_size)
self.layers = nn.ModuleList([
ArcticDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config, int(
prefix.split(".")[-1]), cache_config, quant_config),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
......@@ -388,17 +393,25 @@ class ArcticModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
return hidden_states
class ArcticForCausalLM(nn.Module):
class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self,
config: ArcticConfig,
......@@ -422,6 +435,8 @@ class ArcticForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -430,9 +445,9 @@ class ArcticForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -503,6 +518,8 @@ class ArcticForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -512,6 +529,8 @@ class ArcticForCausalLM(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -522,6 +541,8 @@ class ArcticForCausalLM(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
......@@ -532,6 +553,8 @@ class ArcticForCausalLM(nn.Module):
else:
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
......
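Note on the indexing change above: the loop now runs over `range(self.start_layer, self.end_layer)` and reads `kv_caches[i - self.start_layer]`, which suggests each pipeline rank receives KV caches only for the layers it owns. The sketch below illustrates that arithmetic. `partition_layers` is a hypothetical even-split helper written for this note, and the real `make_layers` utility may divide layers differently.

```python
from typing import List, Tuple


def partition_layers(num_layers: int, pp_rank: int,
                     pp_size: int) -> Tuple[int, int]:
    """Return the half-open range [start, end) of layers owned by a rank.

    Assumption: layers are split as evenly as possible and the first
    `num_layers % pp_size` ranks take one extra layer.
    """
    base, extra = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, extra)
    end = start + base + (1 if pp_rank < extra else 0)
    return start, end


def local_cache_indices(start_layer: int,
                        end_layer: int) -> List[Tuple[int, int]]:
    """Pair each global layer index with its index into the local kv_caches."""
    return [(i, i - start_layer) for i in range(start_layer, end_layer)]


if __name__ == "__main__":
    for rank in range(4):
        start, end = partition_layers(10, rank, 4)
        print(rank, (start, end), local_cache_indices(start, end))
```

With the global index `i` kept for `self.layers[i]` and the shifted index used for the cache list, a rank that owns layers 3 to 5 reads cache slots 0 to 2.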
......@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -35,8 +35,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -45,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......@@ -255,7 +256,8 @@ class BaiChuanModel(nn.Module):
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -265,12 +267,16 @@ class BaiChuanModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, cache_config,
quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -278,23 +284,34 @@ class BaiChuanModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual,
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
......@@ -335,6 +352,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -343,9 +362,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -394,6 +413,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -402,6 +423,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......@@ -413,7 +436,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
......@@ -431,7 +454,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
......
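Note on the `forward` changes above: the first rank builds `hidden_states` from the embedding, later ranks take it from `intermediate_tensors`, and every rank except the last returns an `IntermediateTensors` instead of the normalized output. The toy simulation below mirrors that control flow with plain floats and a dict so it runs without torch or vLLM. The function `run_stage` and the stand-in arithmetic are assumptions for illustration only.

```python
from typing import Callable, Dict, List, Optional, Union

Tensors = Dict[str, float]  # stand-in for IntermediateTensors


def run_stage(
    layers: List[Callable[[float], float]],
    start: int,
    end: int,
    is_first: bool,
    is_last: bool,
    token: float,
    intermediate: Optional[Tensors],
) -> Union[float, Tensors]:
    if is_first:
        hidden = token * 2.0  # stand-in for the embedding lookup
    else:
        # Later stages must be handed the previous stage's output.
        assert intermediate is not None
        hidden = intermediate["hidden_states"]

    # Each stage applies only the layers it owns.
    for i in range(start, end):
        hidden = layers[i](hidden)

    if not is_last:
        # Hand the activations to the next stage instead of finishing.
        return {"hidden_states": hidden}
    return hidden - 1.0  # stand-in for the final norm


if __name__ == "__main__":
    layers = [lambda x, k=k: x + k for k in range(4)]
    whole = run_stage(layers, 0, 4, True, True, 3.0, None)
    first_half = run_stage(layers, 0, 2, True, False, 3.0, None)
    second_half = run_stage(layers, 2, 4, False, True, 0.0, first_half)
    print(whole, second_half)
```

Chaining the two half-stages gives the same value as running all layers on one rank, which is the property the split `forward` needs to preserve. Models that carry a `residual` alongside `hidden_states`, as BaiChuan does above, would pass both keys in the dict.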
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -19,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
merge_multimodal_embeddings)
......@@ -475,7 +476,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: Blip2Config,
......@@ -508,6 +509,16 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -600,7 +611,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the
......@@ -631,6 +642,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`Blip2ImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
......@@ -646,10 +661,12 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
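Note on the `sampler` property above: the wrapper reuses the inner language model's sampler when one exists and otherwise creates its own, and `cached_property` ensures the fallback is built only once. A minimal sketch of that delegation, with placeholder classes that are not vLLM types:

```python
from functools import cached_property


class Sampler:
    """Placeholder for a sampler object (hypothetical, not vLLM's class)."""


class InnerWithSampler:
    def __init__(self) -> None:
        self.sampler = Sampler()


class InnerWithoutSampler:
    pass


class Wrapper:
    def __init__(self, language_model: object) -> None:
        self.language_model = language_model

    @cached_property
    def sampler(self) -> Sampler:
        # Prefer the wrapped model's sampler; fall back to a fresh one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return Sampler()
```

Because the property is cached, repeated accesses on a wrapper whose inner model has no sampler return the same fallback instance rather than a new one each time.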
......@@ -17,7 +17,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -25,15 +25,14 @@ from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
......@@ -222,6 +225,7 @@ class BloomModel(nn.Module):
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -235,13 +239,16 @@ class BloomModel(nn.Module):
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: BloomBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
......@@ -249,22 +256,29 @@ class BloomModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module):
class BloomForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -284,6 +298,8 @@ class BloomForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -292,9 +308,9 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -321,6 +337,8 @@ class BloomForCausalLM(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
......
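Note on the `is_pp_missing_parameter` guards above: a checkpoint lists weights for every layer, but under pipeline parallelism a rank holds parameters only for its own slice, so `params_dict[name]` would fail for the rest. The sketch below shows one way such a skip could work, by parsing the layer index out of the weight name. This name-based test is an assumption for illustration; the actual helper in `.utils` may decide differently, and the sketch covers only per-layer weights.

```python
import re
from typing import Dict, Iterable, List, Optional, Tuple

# Hypothetical naming scheme: "<prefix>.layers.<index>.<rest>"
_LAYER_RE = re.compile(r"\.layers\.(\d+)\.")


def layer_index(name: str) -> Optional[int]:
    match = _LAYER_RE.search(name)
    return int(match.group(1)) if match else None


def is_missing_on_this_rank(name: str, start_layer: int,
                            end_layer: int) -> bool:
    """True if `name` belongs to a layer outside [start_layer, end_layer)."""
    idx = layer_index(name)
    return idx is not None and not (start_layer <= idx < end_layer)


def load_local_weights(weights: Iterable[Tuple[str, float]],
                       params: Dict[str, Optional[float]],
                       start_layer: int, end_layer: int) -> List[str]:
    loaded = []
    for name, value in weights:
        if is_missing_on_this_rank(name, start_layer, end_layer):
            continue  # this layer lives on another pipeline rank
        if name not in params:
            raise KeyError(name)  # mirrors the params_dict[name] lookup
        params[name] = value
        loaded.append(name)
    return loaded
```

A rank that owns layers 2 and 3 of a four-layer checkpoint would load those two layers plus any non-layer weights present in its `params`, and skip layers 0 and 1 without raising.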
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
Tuple, TypedDict, Union)
import torch
import torch.nn.functional as F
......@@ -10,7 +10,7 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -33,7 +33,9 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
......@@ -822,6 +824,7 @@ class ChameleonModel(nn.Module):
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -835,14 +838,20 @@ class ChameleonModel(nn.Module):
config.vocabulary_map)
decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
else ChameleonSwinDecoderLayer
self.layers = nn.ModuleList([
decoder_layer(config=config,
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
quant_config=quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.vqmodel = ChameleonVQVAE(config.vq_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -865,22 +874,33 @@ class ChameleonModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -889,7 +909,8 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
......@@ -914,6 +935,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
......@@ -956,22 +979,26 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype))
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -1039,6 +1066,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -1060,11 +1089,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
continue
else:
name = remapped_kv_scale_name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if use_default_weight_loading and name in params_dict:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -10,15 +10,14 @@ from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -28,14 +27,16 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GLMAttention(nn.Module):
def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
......@@ -126,7 +127,7 @@ class GLMMLP(nn.Module):
def __init__(
self,
config,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -169,7 +170,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
......@@ -240,9 +241,10 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
......@@ -251,10 +253,11 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList([
GLMBlock(config, cache_config, quant_config)
for i in range(self.num_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
......@@ -269,16 +272,16 @@ class GLMTransformer(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.num_layers):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
kv_cache=kv_caches[i - self.start_layer],
attn_metadata=attn_metadata,
)
# Final layer norm.
if self.post_layer_norm:
if get_pp_group().is_last_rank and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
......@@ -288,7 +291,7 @@ class ChatGLMModel(nn.Module):
def __init__(
self,
config,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
......@@ -305,6 +308,9 @@ class ChatGLMModel(nn.Module):
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
......@@ -312,8 +318,12 @@ class ChatGLMModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.embedding(input_ids)
else:
inputs_embeds = intermediate_tensors["hidden_states"]
# Run encoder.
hidden_states = self.encoder(
......@@ -322,10 +332,13 @@ class ChatGLMModel(nn.Module):
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
......@@ -362,6 +375,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -370,9 +385,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -402,6 +417,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
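The loop change above (iterating `range(self.start_layer, self.end_layer)` while indexing `kv_caches[i - self.start_layer]`) can be illustrated with a small, dependency-free sketch. It assumes the layer list is indexed by global layer number and each stage's KV-cache list holds only that stage's layers, which is what the indexing in the diff implies. `partition_layers`, `make_layer`, and `run_stage` are hypothetical names for illustration, and the even split is an assumption; the real `make_layers` helper may partition differently.

```python
from typing import Callable, List, Tuple


def partition_layers(num_layers: int, pp_rank: int,
                     pp_size: int) -> Tuple[int, int]:
    """Hypothetical even split; returns the half-open range [start, end)."""
    per_stage = num_layers // pp_size
    start = pp_rank * per_stage
    # Any remainder goes to the last stage so each layer is owned once.
    end = num_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end


def make_layer(idx: int) -> Callable[[int, List[int]], int]:
    """Toy layer: records its global index in the cache it was given."""

    def layer(hidden: int, kv_cache: List[int]) -> int:
        kv_cache.append(idx)
        return hidden + 1

    return layer


def run_stage(hidden: int, layers: list, kv_caches: List[List[int]],
              start_layer: int, end_layer: int) -> int:
    # Layers use the global index; the per-stage cache list starts at 0,
    # hence the `i - start_layer` offset.
    for i in range(start_layer, end_layer):
        hidden = layers[i](hidden, kv_caches[i - start_layer])
    return hidden


NUM_LAYERS, PP_SIZE = 6, 2
layers = [make_layer(i) for i in range(NUM_LAYERS)]

hidden = 0
touched = []
for rank in range(PP_SIZE):
    start, end = partition_layers(NUM_LAYERS, rank, PP_SIZE)
    kv_caches = [[] for _ in range(end - start)]  # local to this stage
    hidden = run_stage(hidden, layers, kv_caches, start, end)
    touched.append(kv_caches)
```

Without the offset, the second stage would index `kv_caches[3]` in a list of length 3 and fail, which is the bug class the diff avoids.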
......@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -29,14 +29,13 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -47,7 +46,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@torch.compile
......@@ -82,7 +83,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
config,
config: CohereConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -256,6 +257,7 @@ class CohereModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -265,12 +267,16 @@ class CohereModel(nn.Module):
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, cache_config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: CohereDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
self.norm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -278,23 +284,34 @@ class CohereModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class CohereForCausalLM(nn.Module, SupportsLoRA):
class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -337,6 +354,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
quant_config,
lora_config=lora_config)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@torch.no_grad()
def forward(
......@@ -346,9 +365,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -393,6 +412,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -405,6 +426,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
# coding=utf-8
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -24,6 +23,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DbrxRouter(nn.Module):
"""A Router implementation for DBRX that returns logits for each expert
......@@ -296,22 +299,27 @@ class DbrxModel(nn.Module):
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList([
DbrxBlock(config, cache_config, quant_config)
for _ in range(config.n_layers)
])
self.start_layer, self.end_layer, self.blocks = make_layers(
config.n_layers,
lambda prefix: DbrxBlock(config, cache_config, quant_config),
prefix=f"{prefix}.blocks",
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))
def forward(
self,
......@@ -319,21 +327,28 @@ class DbrxModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states)
return hidden_states
class DbrxForCausalLM(nn.Module):
class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -359,6 +374,8 @@ class DbrxForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -367,9 +384,9 @@ class DbrxForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -401,11 +418,15 @@ class DbrxForCausalLM(nn.Module):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
break
else:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -29,11 +29,12 @@ import torch
from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM):
"""
......@@ -91,6 +92,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -99,6 +102,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
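The `is_pp_missing_parameter(name, self)` guard added before each `params_dict[name]` lookup skips checkpoint entries that the current pipeline stage does not hold. A minimal sketch of the idea follows, assuming weights are named with a `.layers.<index>.` component. `is_missing_on_this_stage` and `load_stage_weights` are hypothetical stand-ins; the actual helper's implementation is not shown in this diff and may decide membership differently.

```python
import re
from typing import Dict, Iterable, List, Optional, Tuple

# Assumed naming scheme for per-layer weights, e.g. "model.layers.3.mlp.weight".
LAYER_RE = re.compile(r"\.layers\.(\d+)\.")


def is_missing_on_this_stage(name: str, start_layer: int,
                             end_layer: int) -> bool:
    """Hypothetical check: True if `name` belongs to a layer outside [start, end)."""
    match = LAYER_RE.search(name)
    if match is None:
        return False  # not a per-layer weight; this sketch keeps it
    idx = int(match.group(1))
    return not (start_layer <= idx < end_layer)


def load_stage_weights(weights: Iterable[Tuple[str, float]],
                       params: Dict[str, Optional[float]],
                       start_layer: int, end_layer: int) -> List[str]:
    loaded = []
    for name, value in weights:
        if is_missing_on_this_stage(name, start_layer, end_layer):
            continue  # without this, params[name] below raises KeyError
        _ = params[name]  # lookup mirrors `param = params_dict[name]`
        params[name] = value
        loaded.append(name)
    return loaded


checkpoint = [
    ("model.embed_tokens.weight", 1.0),
    ("model.layers.0.mlp.weight", 2.0),
    ("model.layers.1.mlp.weight", 3.0),
    ("model.layers.2.mlp.weight", 4.0),
    ("model.layers.3.mlp.weight", 5.0),
]
# Second of two stages: it only registers parameters for layers 2 and 3.
stage_params: Dict[str, Optional[float]] = {
    "model.embed_tokens.weight": None,
    "model.layers.2.mlp.weight": None,
    "model.layers.3.mlp.weight": None,
}
loaded = load_stage_weights(checkpoint, stage_params, 2, 4)
```

The guard is placed after the name rewriting (as in the diff) so the check sees the final parameter name.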
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -50,6 +49,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekMLP(nn.Module):
......@@ -329,6 +332,7 @@ class DeepseekModel(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -338,14 +342,17 @@ class DeepseekModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
quant_config=quant_config),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -353,19 +360,29 @@ class DeepseekModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata,
residual)
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -384,6 +401,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -392,9 +411,9 @@ class DeepseekForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -439,6 +458,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -451,6 +472,8 @@ class DeepseekForCausalLM(nn.Module):
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -40,8 +40,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -50,7 +49,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class DeepseekV2MLP(nn.Module):
......@@ -439,6 +440,9 @@ class DeepseekV2Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -447,7 +451,7 @@ class DeepseekV2Model(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
......@@ -472,7 +476,7 @@ class DeepseekV2Model(nn.Module):
return hidden_states
class DeepseekV2ForCausalLM(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -492,6 +496,8 @@ class DeepseekV2ForCausalLM(nn.Module):
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -500,7 +506,7 @@ class DeepseekV2ForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
......
......@@ -38,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
......@@ -53,8 +52,9 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class ExaoneGatedMLP(nn.Module):
......@@ -354,6 +354,10 @@ class ExaoneModel(nn.Module):
else:
self.ln_f = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
......@@ -397,7 +401,7 @@ class ExaoneModel(nn.Module):
return hidden_states
class ExaoneForCausalLM(nn.Module, SupportsLoRA):
class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -477,6 +481,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
......@@ -506,24 +513,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
"residual":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......
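The Exaone hunk above replaces a hand-written `make_empty_intermediate_tensors` method with a call to `make_empty_intermediate_tensors_factory(keys, hidden_size)`. The closure pattern behind that refactor can be sketched as below. This is an illustration only: plain nested lists stand in for `torch.zeros`, the `dtype` and `device` arguments of the removed method are omitted, and `make_empty_intermediates_factory` is a hypothetical name, not the library function.

```python
from typing import Callable, Dict, List

Matrix = List[List[float]]


def make_empty_intermediates_factory(
        keys: List[str],
        hidden_size: int) -> Callable[[int], Dict[str, Matrix]]:
    """Return a function that builds zero-filled buffers for each key."""

    def make_empty(batch_size: int) -> Dict[str, Matrix]:
        # One (batch_size, hidden_size) zero matrix per key, mirroring the
        # shape used by the removed hand-written method.
        return {
            key: [[0.0] * hidden_size for _ in range(batch_size)]
            for key in keys
        }

    return make_empty


# A model that passes both tensors between stages configures the keys once.
make_empty = make_empty_intermediates_factory(["hidden_states", "residual"], 4)
buffers = make_empty(2)
```

The design benefit is that each model only declares which keys it hands to the next stage (for example `["hidden_states"]` versus `["hidden_states", "residual"]`) instead of repeating the allocation code.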