Unverified commit d58268c5, authored by Joe Runde and committed by GitHub
Browse files

[V1] Make v1 more testable (#9888)


Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 87bd7e05
...@@ -191,6 +191,9 @@ ADD . /vllm-workspace/ ...@@ -191,6 +191,9 @@ ADD . /vllm-workspace/
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt python3 -m pip install -r requirements-dev.txt
# Copy in the v1 package for testing (it isn't distributed yet)
COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
# doc requires source code # doc requires source code
# we hide them inside `test_docs/` , so that this source code # we hide them inside `test_docs/` , so that this source code
# will not be imported by other tests # will not be imported by other tests
......
...@@ -97,4 +97,5 @@ markers = [ ...@@ -97,4 +97,5 @@ markers = [
"skip_global_cleanup", "skip_global_cleanup",
"core_model: run this model test in each PR instead of just daily", "core_model: run this model test in each PR instead of just daily",
"distributed_2_gpus: run this test only in distributed tests for 2 GPUs", "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
"skip_v1: do not run this test with v1",
] ]
...@@ -5,6 +5,7 @@ from collections import UserList ...@@ -5,6 +5,7 @@ from collections import UserList
from enum import Enum from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
TypedDict, TypeVar, Union) TypedDict, TypeVar, Union)
from unittest.mock import patch
import numpy as np import numpy as np
import pytest import pytest
...@@ -108,6 +109,23 @@ VIDEO_ASSETS = _VideoAssets() ...@@ -108,6 +109,23 @@ VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`.""" """Singleton instance of :class:`_VideoAssets`."""
@pytest.fixture(params=[True, False])
def run_with_both_engines(request):
    """Run the requesting test once for each value of ``VLLM_USE_V1``.

    The fixture is parametrized over ``True`` and ``False``; for the duration
    of the test ``vllm.envs.VLLM_USE_V1`` is patched to the current parameter.
    Tests carrying the ``skip_v1`` marker are skipped for the ``True`` (V1)
    parametrization and only run without V1.
    """
    v1_enabled = request.param
    skip_marker = request.node.get_closest_marker("skip_v1")

    # Bail out before patching anything when this test opted out of V1.
    if v1_enabled and skip_marker:
        pytest.skip("Skipping test on vllm V1")

    # The parameter is exactly True or False, so it can be used directly as
    # the patched value for either engine.
    with patch('vllm.envs.VLLM_USE_V1', v1_enabled):
        yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def init_test_http_connection(): def init_test_http_connection():
# pytest_asyncio may use a different event loop per test # pytest_asyncio may use a different event loop per test
......
...@@ -3,12 +3,21 @@ import pytest ...@@ -3,12 +3,21 @@ import pytest
from vllm import LLM from vllm import LLM
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse wrapper that requests ``run_with_both_engines`` for each test.

    It does no work of its own; depending on the fixture is what makes the
    tests run with both engines. This can be promoted up to ``conftest.py``
    to run for every test in a package.
    """
def test_empty_prompt(): def test_empty_prompt():
llm = LLM(model="gpt2", enforce_eager=True) llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'): with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""]) llm.generate([""])
@pytest.mark.skip_v1
def test_out_of_vocab_token(): def test_out_of_vocab_token():
llm = LLM(model="gpt2", enforce_eager=True) llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='out of vocabulary'): with pytest.raises(ValueError, match='out of vocabulary'):
......
...@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch): ...@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch):
def test_flash_attn(monkeypatch): def test_flash_attn(monkeypatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
......
...@@ -16,7 +16,7 @@ from tests.kernels.utils import * ...@@ -16,7 +16,7 @@ from tests.kernels.utils import *
from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
AttentionType) AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend): ...@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
default_dtype = torch.get_default_dtype() default_dtype = torch.get_default_dtype()
if attn_backend.name == 'FLASH_ATTN': if attn_backend.name == 'FLASH_ATTN':
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
get_attn_backend.cache_clear() _cached_get_attn_backend.cache_clear()
yield yield
# Reset the torch datatype to what it was before the test # Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests. # so as not to impact the remaining tests.
......
...@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: ...@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend return forced_attn_backend
@lru_cache(maxsize=None)
def get_attn_backend( def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -99,6 +98,31 @@ def get_attn_backend( ...@@ -99,6 +98,31 @@ def get_attn_backend(
is_blocksparse: bool = False, is_blocksparse: bool = False,
) -> Type[AttentionBackend]: ) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
)
@lru_cache(maxsize=None)
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse: if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.") logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import ( from vllm.attention.backends.blocksparse_attn import (
...@@ -106,7 +130,7 @@ def get_attn_backend( ...@@ -106,7 +130,7 @@ def get_attn_backend(
return BlocksparseFlashAttentionBackend return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free) is_attention_free, use_v1)
if backend == _Backend.FLASH_ATTN: if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.") logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
...@@ -162,13 +186,12 @@ def get_attn_backend( ...@@ -162,13 +186,12 @@ def get_attn_backend(
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
def which_attn_to_use( def which_attn_to_use(head_size: int,
head_size: int, dtype: torch.dtype,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
kv_cache_dtype: Optional[str], block_size: int,
block_size: int, is_attention_free: bool,
is_attention_free: bool, use_v1: bool = False) -> _Backend:
) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
...@@ -228,7 +251,7 @@ def which_attn_to_use( ...@@ -228,7 +251,7 @@ def which_attn_to_use(
if current_platform.is_hpu(): if current_platform.is_hpu():
return _Backend.HPU_ATTN return _Backend.HPU_ATTN
if envs.VLLM_USE_V1: if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1 return _Backend.FLASH_ATTN_VLLM_V1
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
......
...@@ -6,7 +6,9 @@ from typing import Iterator, List, Optional, Union ...@@ -6,7 +6,9 @@ from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
import zmq import zmq
import vllm.envs
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...@@ -17,17 +19,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -17,17 +19,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse, RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest) RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
if VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine
else:
from vllm.engine.llm_engine import LLMEngine
logger = init_logger(__name__) logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000 POLLING_TIMEOUT_MS = 10000
...@@ -117,11 +113,17 @@ class MQLLMEngine: ...@@ -117,11 +113,17 @@ class MQLLMEngine:
load_general_plugins() load_general_plugins()
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
if vllm.envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
engine_class = V1LLMEngine
else:
engine_class = LLMEngine
executor_class = LLMEngine._get_executor_cls(engine_config) executor_class = engine_class._get_executor_cls(engine_config)
use_async_sockets = (engine_config.model_config.use_async_output_proc use_async_sockets = (engine_config.model_config.use_async_output_proc
and not VLLM_USE_V1) and not vllm.envs.VLLM_USE_V1)
return cls(ipc_path=ipc_path, return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets, use_async_sockets=use_async_sockets,
......
import itertools import itertools
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload) Union, cast, overload)
from tqdm import tqdm from tqdm import tqdm
...@@ -10,6 +10,7 @@ from vllm import envs ...@@ -10,6 +10,7 @@ from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs, TaskOption from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
...@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup ...@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
else:
from vllm.engine.llm_engine import LLMEngine # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -206,10 +202,21 @@ class LLM: ...@@ -206,10 +202,21 @@ class LLM:
pooling_returned_token_ids=pooling_returned_token_ids, pooling_returned_token_ids=pooling_returned_token_ids,
**kwargs, **kwargs,
) )
self.llm_engine = LLMEngine.from_engine_args( # Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS) engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter() self.request_counter = Counter()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
    """Return the engine class selected by ``envs.VLLM_USE_V1``.

    The flag is read on every call rather than at import time, so the choice
    follows the value of ``envs.VLLM_USE_V1`` at the moment of the call.

    Returns:
        The V1 ``LLMEngine`` class when ``envs.VLLM_USE_V1`` is truthy,
        otherwise the ``LLMEngine`` class already imported by this module.
    """
    if not envs.VLLM_USE_V1:
        return LLMEngine

    # Lazy import: the v1 package isn't distributed
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
    return V1LLMEngine  # type: ignore
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
...@@ -394,7 +401,7 @@ class LLM: ...@@ -394,7 +401,7 @@ class LLM:
priority=priority) priority=priority)
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput) return self.engine_class.validate_outputs(outputs, RequestOutput)
def beam_search( def beam_search(
self, self,
...@@ -769,7 +776,8 @@ class LLM: ...@@ -769,7 +776,8 @@ class LLM:
) )
outputs = self._run_engine(use_tqdm=use_tqdm) outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def start_profile(self) -> None: def start_profile(self) -> None:
self.llm_engine.start_profile() self.llm_engine.start_profile()
......
...@@ -30,6 +30,15 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): ...@@ -30,6 +30,15 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
else: else:
flashinfer_top_k_top_p_sampling = None flashinfer_top_k_top_p_sampling = None
def get_sampler() -> torch.nn.Module:
    """Build a new sampler instance for the engine selected by ``envs``.

    Returns:
        A fresh V1 ``Sampler`` when ``envs.VLLM_USE_V1`` is truthy, otherwise
        a fresh instance of this module's ``Sampler``.
    """
    if not envs.VLLM_USE_V1:
        return Sampler()

    # Lazy import: the v1 package isn't distributed
    from vllm.v1.sample.sampler import Sampler as V1Sampler
    return V1Sampler()
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = List[Tuple[List[int], List[int]]]
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig, DeepSpeedFPParameter) DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -436,7 +436,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -436,7 +436,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -352,7 +352,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -352,7 +352,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -838,7 +838,7 @@ class BartForConditionalGeneration(nn.Module): ...@@ -838,7 +838,7 @@ class BartForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
def forward( def forward(
self, self,
......
...@@ -13,7 +13,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, ...@@ -13,7 +13,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.multimodal.utils import consecutive_placeholder_ranges
...@@ -525,7 +525,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -525,7 +525,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
return self.language_model.sampler return self.language_model.sampler
return Sampler() return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
......
...@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -298,7 +298,7 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -298,7 +298,7 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self.config.hidden_size) self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -946,7 +946,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -946,7 +946,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -616,7 +616,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -616,7 +616,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self.transformer.embedding.weight) self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -355,7 +355,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -355,7 +355,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
cache_config, cache_config,
quant_config, quant_config,
lora_config=lora_config) lora_config=lora_config)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -373,7 +373,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -373,7 +373,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
) )
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
......
...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -399,7 +399,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -399,7 +399,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment