Unverified Commit d4d93db2 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[V1] V1 Enablement Oracle (#13726)


Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarNicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent 8c0d15d5
......@@ -595,6 +595,13 @@ class AsyncLLMEngine(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
**kwargs) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs)
......@@ -629,33 +636,53 @@ class AsyncLLMEngine(EngineClient):
engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLMEngine":
"""Create an AsyncLLMEngine from the EngineArgs."""
return cls(
vllm_config=vllm_config,
executor_class=cls._get_executor_cls(vllm_config),
start_engine_loop=start_engine_loop,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
if engine_config is None:
engine_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the async LLM engine.
engine = cls(
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
vllm_config = engine_args.create_engine_config(usage_context)
async_engine_cls = cls
if envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
async_engine_cls = V1AsyncLLMEngine
return async_engine_cls.from_vllm_config(
vllm_config=vllm_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
disable_log_requests=engine_args.disable_log_requests,
)
return engine
@property
def is_running(self) -> bool:
......@@ -1203,7 +1230,7 @@ class AsyncLLMEngine(EngineClient):
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM
AsyncLLMEngine = AsyncLLM # type: ignore
......@@ -216,6 +216,12 @@ class LLMEngine:
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
......@@ -479,6 +485,22 @@ class LLMEngine:
f"{distributed_executor_backend}")
return executor_class
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
return cls(
vllm_config=vllm_config,
executor_class=cls._get_executor_cls(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
)
@classmethod
def from_engine_args(
cls,
......@@ -488,19 +510,20 @@ class LLMEngine:
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine.
engine = cls(
vllm_config=engine_config,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
vllm_config = engine_args.create_engine_config(usage_context)
engine_cls = cls
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
engine_cls = V1LLMEngine
return engine_cls.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
)
return engine
def __reduce__(self):
# This is to ensure that the LLMEngine is not referenced in
# the closure used to initialize Ray worker actors
......@@ -2097,6 +2120,6 @@ class LLMEngine:
return sampling_params
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
LLMEngine = V1LLMEngine # type: ignore
......@@ -18,7 +18,6 @@ from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
......@@ -133,9 +132,9 @@ class MQLLMEngineClient(EngineClient):
self._engine_process = psutil.Process(engine_pid)
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported
return engine_args.pipeline_parallel_size > 1
return vllm_config.parallel_config.pipeline_parallel_size > 1
@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
......
......@@ -9,6 +9,7 @@ import cloudpickle
import zmq
from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import VllmConfig
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
......@@ -110,25 +111,39 @@ class MQLLMEngine:
return ENGINE_DEAD_ERROR()
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
def from_vllm_config(cls, vllm_config: VllmConfig,
usage_context: UsageContext,
disable_log_requests: bool, disable_log_stats: bool,
ipc_path: str) -> "MQLLMEngine":
# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
engine_config = engine_args.create_engine_config(usage_context)
executor_class = LLMEngine._get_executor_cls(engine_config)
use_async_sockets = vllm_config.model_config.use_async_output_proc
return cls(
vllm_config=vllm_config,
executor_class=LLMEngine._get_executor_cls(vllm_config),
ipc_path=ipc_path,
usage_context=usage_context,
use_async_sockets=use_async_sockets,
log_requests=(not disable_log_requests),
log_stats=(not disable_log_stats),
)
use_async_sockets = engine_config.model_config.use_async_output_proc
@staticmethod
def from_engine_args(engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
vllm_config = engine_args.create_engine_config(usage_context)
return MQLLMEngine.from_vllm_config(
ipc_path=ipc_path,
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
)
def start(self):
try:
......@@ -396,12 +411,16 @@ def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, engine_alive):
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive):
try:
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
engine = MQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_stats=disable_log_stats,
disable_log_requests=disable_log_requests,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler)
......
......@@ -11,7 +11,6 @@ import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeVar, deprecated
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
......@@ -238,23 +237,15 @@ class LLM:
compilation_config=compilation_config_instance,
**kwargs,
)
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
# Create the Engine (autoselects V0 vs V1)
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine)
self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None
@staticmethod
def get_engine_class() -> type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
return V1LLMEngine # type: ignore
return LLMEngine
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
......
......@@ -154,21 +154,47 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""
# AsyncLLMEngine.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
# Create the EngineConfig (determines if we can use V1).
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM.
if envs.VLLM_USE_V1:
if disable_frontend_multiprocessing:
logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
# V0 AsyncLLM.
elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
or disable_frontend_multiprocessing):
engine_client: Optional[EngineClient] = None
try:
engine_client = AsyncLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.OPENAI_API_SERVER)
engine_client = AsyncLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
yield engine_client
finally:
if engine_client and hasattr(engine_client, "shutdown"):
engine_client.shutdown()
# MQLLMEngine.
# V0MQLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
......@@ -199,10 +225,11 @@ async def build_async_engine_client_from_engine_args(
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path, engine_alive))
engine_process = context.Process(
target=run_mp_engine,
args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
engine_args.disable_log_stats,
engine_args.disable_log_requests, engine_alive))
engine_process.start()
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start."
......@@ -217,8 +244,7 @@ async def build_async_engine_client_from_engine_args(
atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol.
engine_config = engine_args.create_engine_config()
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
......
......@@ -74,7 +74,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = False
VLLM_USE_V1: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
......@@ -522,7 +522,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
......@@ -644,3 +644,19 @@ def __getattr__(name: str):
def __dir__():
return list(environment_variables.keys())
def is_set(name: str):
"""Check if an environment variable is explicitly set."""
if name in environment_variables:
return name in os.environ
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def set_vllm_use_v1(use_v1: bool):
if is_set("VLLM_USE_V1"):
raise ValueError(
"Should not call set_vllm_use_v1() if VLLM_USE_V1 is set "
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
......@@ -74,7 +74,8 @@ def resolve_transformers_fallback(model_config: ModelConfig,
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM.")
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
......
......@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .interfaces import SupportsPP, SupportsV0Only
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
return hidden_states
class BloomForCausalLM(nn.Module, SupportsPP):
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -3,10 +3,11 @@
from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from .interfaces import SupportsV0Only
from .utils import PPMissingLayer
class GlmForCausalLM(LlamaForCausalLM):
class GlmForCausalLM(LlamaForCausalLM, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
......
......@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
SupportsMultiModal, SupportsPP, SupportsV0Only)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings,
......@@ -405,7 +405,8 @@ class ModifiedWhisperEncoder(WhisperEncoder):
UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
SupportsV0Only):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
......
......@@ -196,7 +196,8 @@ class FlashAttentionImpl(AttentionImpl):
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
......
......@@ -8,6 +8,7 @@ from typing import Optional, Union
import numpy as np
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
......@@ -49,6 +50,12 @@ class AsyncLLM(EngineClient):
log_requests: bool = True,
start_engine_loop: bool = True,
) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
assert start_engine_loop
......@@ -92,22 +99,50 @@ class AsyncLLM(EngineClient):
self.output_handler: Optional[asyncio.Task] = None
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# FIXME(rob): refactor VllmConfig to include the StatLoggers
# include StatLogger in the Oracle decision.
if stat_loggers is not None:
raise ValueError("Custom StatLoggers are not yet supported on V1. "
"Explicitly set VLLM_USE_V1=0 to disable V1.")
# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
if engine_config is None:
vllm_config = engine_args.create_engine_config(usage_context)
else:
vllm_config = engine_config
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM.
......
......@@ -46,6 +46,13 @@ class LLMEngine:
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
......@@ -88,6 +95,26 @@ class LLMEngine:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod
def from_engine_args(
cls,
......
......@@ -184,7 +184,7 @@ class Processor:
# Only applicable to multimodal models with legacy input processor.
processed_inputs = self.input_processor(preprocessed_inputs)
self._validate_model_inputs(processed_inputs)
self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
......@@ -200,8 +200,12 @@ class Processor:
raise NotImplementedError
assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case
# TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (self.model_config.max_model_len -
len(decoder_inputs.prompt_token_ids))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
......@@ -296,7 +300,9 @@ class Processor:
lora_request=lora_request,
)
def _validate_model_inputs(self, inputs: ProcessorInputs):
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
......@@ -310,6 +316,13 @@ class Processor:
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
max_input_id = max(prompt_ids)
max_allowed = self.tokenizer.get_lora_tokenizer(
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if len(prompt_ids) >= self.model_config.max_model_len:
raise ValueError(
f"Prompt length of {len(prompt_ids)} is longer than the "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment