Unverified Commit d4d93db2 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[V1] V1 Enablement Oracle (#13726)


Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarNicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: default avatarMichael Goin <michael@neuralmagic.com>
parent 8c0d15d5
...@@ -595,6 +595,13 @@ class AsyncLLMEngine(EngineClient): ...@@ -595,6 +595,13 @@ class AsyncLLMEngine(EngineClient):
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.log_requests = log_requests self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs) self.engine = self._engine_class(*args, **kwargs)
...@@ -629,33 +636,53 @@ class AsyncLLMEngine(EngineClient): ...@@ -629,33 +636,53 @@ class AsyncLLMEngine(EngineClient):
engine_config: VllmConfig) -> Type[ExecutorBase]: engine_config: VllmConfig) -> Type[ExecutorBase]:
return LLMEngine._get_executor_cls(engine_config) return LLMEngine._get_executor_cls(engine_config)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLMEngine":
"""Create an AsyncLLMEngine from the EngineArgs."""
return cls(
vllm_config=vllm_config,
executor_class=cls._get_executor_cls(vllm_config),
start_engine_loop=start_engine_loop,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine": ) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
if engine_config is None: vllm_config = engine_args.create_engine_config(usage_context)
engine_config = engine_args.create_engine_config(usage_context)
async_engine_cls = cls
executor_class = cls._get_executor_cls(engine_config) if envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine
# Create the async LLM engine. async_engine_cls = V1AsyncLLMEngine
engine = cls(
vllm_config=engine_config, return async_engine_cls.from_vllm_config(
executor_class=executor_class, vllm_config=vllm_config,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
start_engine_loop=start_engine_loop, start_engine_loop=start_engine_loop,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
disable_log_requests=engine_args.disable_log_requests,
) )
return engine
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
...@@ -1203,7 +1230,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -1203,7 +1230,7 @@ class AsyncLLMEngine(EngineClient):
# TODO(v1): Remove this class proxy when V1 goes default. # TODO(v1): Remove this class proxy when V1 goes default.
if envs.VLLM_USE_V1: if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
AsyncLLMEngine = AsyncLLM # type: ignore AsyncLLMEngine = AsyncLLM # type: ignore
...@@ -216,6 +216,12 @@ class LLMEngine: ...@@ -216,6 +216,12 @@ class LLMEngine:
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
if envs.VLLM_USE_V1:
raise ValueError(
"Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
...@@ -479,6 +485,22 @@ class LLMEngine: ...@@ -479,6 +485,22 @@ class LLMEngine:
f"{distributed_executor_backend}") f"{distributed_executor_backend}")
return executor_class return executor_class
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
return cls(
vllm_config=vllm_config,
executor_class=cls._get_executor_cls(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
)
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
...@@ -488,19 +510,20 @@ class LLMEngine: ...@@ -488,19 +510,20 @@ class LLMEngine:
) -> "LLMEngine": ) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments.""" """Creates an LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config(usage_context) vllm_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine. engine_cls = cls
engine = cls( if envs.VLLM_USE_V1:
vllm_config=engine_config, from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
executor_class=executor_class, engine_cls = V1LLMEngine
log_stats=not engine_args.disable_log_stats,
return engine_cls.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
disable_log_stats=engine_args.disable_log_stats,
) )
return engine
def __reduce__(self): def __reduce__(self):
# This is to ensure that the LLMEngine is not referenced in # This is to ensure that the LLMEngine is not referenced in
# the closure used to initialize Ray worker actors # the closure used to initialize Ray worker actors
...@@ -2097,6 +2120,6 @@ class LLMEngine: ...@@ -2097,6 +2120,6 @@ class LLMEngine:
return sampling_params return sampling_params
# TODO(v1): Remove this class proxy when V1 goes default. if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
if envs.VLLM_USE_V1: from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
from vllm.v1.engine.llm_engine import LLMEngine # type: ignore LLMEngine = V1LLMEngine # type: ignore
...@@ -18,7 +18,6 @@ from zmq.asyncio import Socket ...@@ -18,7 +18,6 @@ from zmq.asyncio import Socket
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.engine.async_llm_engine import ( from vllm.engine.async_llm_engine import (
...@@ -133,9 +132,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -133,9 +132,9 @@ class MQLLMEngineClient(EngineClient):
self._engine_process = psutil.Process(engine_pid) self._engine_process = psutil.Process(engine_pid)
@staticmethod @staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs): def is_unsupported_config(vllm_config: VllmConfig):
# Pipeline parallel not yet supported # Pipeline parallel not yet supported
return engine_args.pipeline_parallel_size > 1 return vllm_config.parallel_config.pipeline_parallel_size > 1
@contextmanager @contextmanager
def get_data_socket(self) -> Iterator[Socket]: def get_data_socket(self) -> Iterator[Socket]:
......
...@@ -9,6 +9,7 @@ import cloudpickle ...@@ -9,6 +9,7 @@ import cloudpickle
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import VllmConfig
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -110,25 +111,39 @@ class MQLLMEngine: ...@@ -110,25 +111,39 @@ class MQLLMEngine:
return ENGINE_DEAD_ERROR() return ENGINE_DEAD_ERROR()
@classmethod @classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs, def from_vllm_config(cls, vllm_config: VllmConfig,
usage_context: UsageContext, ipc_path: str): usage_context: UsageContext,
"""Creates an MQLLMEngine from the engine arguments.""" disable_log_requests: bool, disable_log_stats: bool,
ipc_path: str) -> "MQLLMEngine":
# Setup plugins for each process # Setup plugins for each process
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
load_general_plugins() load_general_plugins()
engine_config = engine_args.create_engine_config(usage_context) use_async_sockets = vllm_config.model_config.use_async_output_proc
executor_class = LLMEngine._get_executor_cls(engine_config)
return cls(
vllm_config=vllm_config,
executor_class=LLMEngine._get_executor_cls(vllm_config),
ipc_path=ipc_path,
usage_context=usage_context,
use_async_sockets=use_async_sockets,
log_requests=(not disable_log_requests),
log_stats=(not disable_log_stats),
)
use_async_sockets = engine_config.model_config.use_async_output_proc @staticmethod
def from_engine_args(engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
return cls(ipc_path=ipc_path, vllm_config = engine_args.create_engine_config(usage_context)
use_async_sockets=use_async_sockets, return MQLLMEngine.from_vllm_config(
vllm_config=engine_config, ipc_path=ipc_path,
executor_class=executor_class, vllm_config=vllm_config,
log_requests=not engine_args.disable_log_requests, usage_context=usage_context,
log_stats=not engine_args.disable_log_stats, disable_log_requests=engine_args.disable_log_requests,
usage_context=usage_context) disable_log_stats=engine_args.disable_log_stats,
)
def start(self): def start(self):
try: try:
...@@ -396,12 +411,16 @@ def signal_handler(*_) -> None: ...@@ -396,12 +411,16 @@ def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated") raise KeyboardInterrupt("MQLLMEngine terminated")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, engine_alive): ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, engine_alive):
try: try:
engine = MQLLMEngine.from_engine_args(engine_args=engine_args, engine = MQLLMEngine.from_vllm_config(
usage_context=usage_context, vllm_config=vllm_config,
ipc_path=ipc_path) usage_context=usage_context,
disable_log_stats=disable_log_stats,
disable_log_requests=disable_log_requests,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
......
...@@ -11,7 +11,6 @@ import torch.nn as nn ...@@ -11,7 +11,6 @@ import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import TypeVar, deprecated from typing_extensions import TypeVar, deprecated
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig from vllm.config import CompilationConfig
...@@ -238,23 +237,15 @@ class LLM: ...@@ -238,23 +237,15 @@ class LLM:
compilation_config=compilation_config_instance, compilation_config=compilation_config_instance,
**kwargs, **kwargs,
) )
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues # Create the Engine (autoselects V0 vs V1)
self.engine_class = self.get_engine_class() self.llm_engine = LLMEngine.from_engine_args(
self.llm_engine = self.engine_class.from_engine_args( engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
engine_args, usage_context=UsageContext.LLM_CLASS) self.engine_class = type(self.llm_engine)
self.request_counter = Counter() self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None self.default_sampling_params: Union[dict[str, Any], None] = None
@staticmethod
def get_engine_class() -> type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
return V1LLMEngine # type: ignore
return LLMEngine
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
......
...@@ -154,21 +154,47 @@ async def build_async_engine_client_from_engine_args( ...@@ -154,21 +154,47 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed. Returns the Client or None if the creation failed.
""" """
# AsyncLLMEngine. # Create the EngineConfig (determines if we can use V1).
if (MQLLMEngineClient.is_unsupported_config(engine_args) usage_context = UsageContext.OPENAI_API_SERVER
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM.
if envs.VLLM_USE_V1:
if disable_frontend_multiprocessing:
logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
# V0 AsyncLLM.
elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
or disable_frontend_multiprocessing):
engine_client: Optional[EngineClient] = None engine_client: Optional[EngineClient] = None
try: try:
engine_client = AsyncLLMEngine.from_engine_args( engine_client = AsyncLLMEngine.from_vllm_config(
engine_args=engine_args, vllm_config=vllm_config,
usage_context=UsageContext.OPENAI_API_SERVER) usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats)
yield engine_client yield engine_client
finally: finally:
if engine_client and hasattr(engine_client, "shutdown"): if engine_client and hasattr(engine_client, "shutdown"):
engine_client.shutdown() engine_client.shutdown()
# MQLLMEngine. # V0MQLLMEngine.
else: else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing # Make TemporaryDirectory for prometheus multiprocessing
...@@ -199,10 +225,11 @@ async def build_async_engine_client_from_engine_args( ...@@ -199,10 +225,11 @@ async def build_async_engine_client_from_engine_args(
# not actually result in an exitcode being reported. As a result # not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information. # we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False) engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine, engine_process = context.Process(
args=(engine_args, target=run_mp_engine,
UsageContext.OPENAI_API_SERVER, args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
ipc_path, engine_alive)) engine_args.disable_log_stats,
engine_args.disable_log_requests, engine_alive))
engine_process.start() engine_process.start()
engine_pid = engine_process.pid engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start." assert engine_pid is not None, "Engine process failed to start."
...@@ -217,8 +244,7 @@ async def build_async_engine_client_from_engine_args( ...@@ -217,8 +244,7 @@ async def build_async_engine_client_from_engine_args(
atexit.register(_cleanup_ipc_path) atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol. # Build RPCClient, which conforms to EngineClient Protocol.
engine_config = engine_args.create_engine_config() build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
engine_pid) engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor( mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client) None, build_client)
......
...@@ -74,7 +74,7 @@ if TYPE_CHECKING: ...@@ -74,7 +74,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = False VLLM_USE_V1: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
...@@ -522,7 +522,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -522,7 +522,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, use the V1 code path. # If set, use the V1 code path.
"VLLM_USE_V1": "VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
# Pad the fp8 weights to 256 bytes for ROCm # Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING": "VLLM_ROCM_FP8_PADDING":
...@@ -644,3 +644,19 @@ def __getattr__(name: str): ...@@ -644,3 +644,19 @@ def __getattr__(name: str):
def __dir__(): def __dir__():
return list(environment_variables.keys()) return list(environment_variables.keys())
def is_set(name: str):
"""Check if an environment variable is explicitly set."""
if name in environment_variables:
return name in os.environ
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def set_vllm_use_v1(use_v1: bool):
if is_set("VLLM_USE_V1"):
raise ValueError(
"Should not call set_vllm_use_v1() if VLLM_USE_V1 is set "
"explicitly by the user. Please raise this as a Github "
"Issue and explicitly set VLLM_USE_V1=0 or 1.")
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
...@@ -74,7 +74,8 @@ def resolve_transformers_fallback(model_config: ModelConfig, ...@@ -74,7 +74,8 @@ def resolve_transformers_fallback(model_config: ModelConfig,
if not is_transformers_impl_compatible(arch, custom_model_module): if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError( raise ValueError(
f"{arch} has no vLLM implementation and the Transformers " f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM.") "implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
logger.warning( logger.warning(
"%s has no vLLM implementation, falling back to Transformers " "%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and " "implementation. Some features may not be supported and "
......
...@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -42,7 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP, SupportsV0Only
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -279,7 +279,7 @@ class BloomModel(nn.Module): ...@@ -279,7 +279,7 @@ class BloomModel(nn.Module):
return hidden_states return hidden_states
class BloomForCausalLM(nn.Module, SupportsPP): class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from .interfaces import SupportsV0Only
from .utils import PPMissingLayer from .utils import PPMissingLayer
class GlmForCausalLM(LlamaForCausalLM): class GlmForCausalLM(LlamaForCausalLM, SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
......
...@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors ...@@ -36,7 +36,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP, SupportsV0Only)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings, merge_multimodal_embeddings,
...@@ -405,7 +405,8 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -405,7 +405,8 @@ class ModifiedWhisperEncoder(WhisperEncoder):
UltravoxMultiModalProcessor, UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo, info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder) dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
SupportsV0Only):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
......
...@@ -196,7 +196,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -196,7 +196,8 @@ class FlashAttentionImpl(AttentionImpl):
if head_size not in support_head_sizes: if head_size not in support_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. " f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.") f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
......
...@@ -8,6 +8,7 @@ from typing import Optional, Union ...@@ -8,6 +8,7 @@ from typing import Optional, Union
import numpy as np import numpy as np
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -49,6 +50,12 @@ class AsyncLLM(EngineClient): ...@@ -49,6 +50,12 @@ class AsyncLLM(EngineClient):
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
) -> None: ) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
assert start_engine_loop assert start_engine_loop
...@@ -92,22 +99,50 @@ class AsyncLLM(EngineClient): ...@@ -92,22 +99,50 @@ class AsyncLLM(EngineClient):
self.output_handler: Optional[asyncio.Task] = None self.output_handler: Optional[asyncio.Task] = None
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
) -> "AsyncLLM":
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"AsyncLLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
# FIXME(rob): refactor VllmConfig to include the StatLoggers
# include StatLogger in the Oracle decision.
if stat_loggers is not None:
raise ValueError("Custom StatLoggers are not yet supported on V1. "
"Explicitly set VLLM_USE_V1=0 to disable V1.")
# Create the LLMEngine.
return cls(
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
start_engine_loop=start_engine_loop,
log_requests=not disable_log_requests,
log_stats=not disable_log_stats,
usage_context=usage_context,
)
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLM": ) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs.""" """Create an AsyncLLM from the EngineArgs."""
# Create the engine configs. # Create the engine configs.
if engine_config is None: vllm_config = engine_args.create_engine_config(usage_context)
vllm_config = engine_args.create_engine_config(usage_context)
else:
vllm_config = engine_config
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM. # Create the AsyncLLM.
......
...@@ -46,6 +46,13 @@ class LLMEngine: ...@@ -46,6 +46,13 @@ class LLMEngine:
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
) -> None: ) -> None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
"This should not happen. As a workaround, try using "
"LLMEngine.from_vllm_config(...) or explicitly set "
"VLLM_USE_V1=0 or 1 and report this issue on Github.")
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
...@@ -88,6 +95,26 @@ class LLMEngine: ...@@ -88,6 +95,26 @@ class LLMEngine:
# for v0 compatibility # for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
disable_log_stats: bool = False,
) -> "LLMEngine":
if stat_loggers is not None:
raise NotImplementedError(
"Passing StatLoggers to V1 is not yet supported. "
"Set VLLM_USE_V1=0 and file and issue on Github.")
return cls(vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=(not disable_log_stats),
usage_context=usage_context,
stat_loggers=stat_loggers,
multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
......
...@@ -184,7 +184,7 @@ class Processor: ...@@ -184,7 +184,7 @@ class Processor:
# Only applicable to multimodal models with legacy input processor. # Only applicable to multimodal models with legacy input processor.
processed_inputs = self.input_processor(preprocessed_inputs) processed_inputs = self.input_processor(preprocessed_inputs)
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs): if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter( decoder_inputs = SingletonInputsAdapter(
...@@ -200,8 +200,12 @@ class Processor: ...@@ -200,8 +200,12 @@ class Processor:
raise NotImplementedError raise NotImplementedError
assert isinstance(params, SamplingParams) assert isinstance(params, SamplingParams)
# TODO: can we avoid cloning here in multiproc case # TODO: can we avoid cloning here in multiproc case?
sampling_params = params.clone() sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None:
sampling_params.max_tokens = (self.model_config.max_model_len -
len(decoder_inputs.prompt_token_ids))
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer( sampling_params.update_from_tokenizer(
...@@ -296,7 +300,9 @@ class Processor: ...@@ -296,7 +300,9 @@ class Processor:
lora_request=lora_request, lora_request=lora_request,
) )
def _validate_model_inputs(self, inputs: ProcessorInputs): def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs): if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # restricts the decoder prompt length
...@@ -310,6 +316,13 @@ class Processor: ...@@ -310,6 +316,13 @@ class Processor:
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
max_input_id = max(prompt_ids)
max_allowed = self.tokenizer.get_lora_tokenizer(
lora_request).max_token_id
if max_input_id > max_allowed:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
if len(prompt_ids) >= self.model_config.max_model_len: if len(prompt_ids) >= self.model_config.max_model_len:
raise ValueError( raise ValueError(
f"Prompt length of {len(prompt_ids)} is longer than the " f"Prompt length of {len(prompt_ids)} is longer than the "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment