Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast ...@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs, ...@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
ErrorResponse, UsageInfo) ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath) PromptAdapterPath)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__( def __init__(
self, self,
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], base_model_paths: List[BaseModelPath],
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]], prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
): ):
super().__init__(async_engine_client=async_engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=prompt_adapters, prompt_adapters=prompt_adapters,
request_logger=request_logger, request_logger=request_logger,
...@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features. # Return error for unsupported features.
if request.suffix is not None: if request.suffix is not None:
return self.create_error_response( return self.create_error_response(
"suffix is not currently supported") "suffix is not currently supported")
model_name = self.served_model_names[0] model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time()) created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = [] generators: List[AsyncGenerator[RequestOutput, None]] = []
try: try:
...@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer( tokenizer = await self.engine_client.get_tokenizer(lora_request)
lora_request)
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer)) await self._guided_decode_logits_processor(request, tokenizer))
...@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = ( is_tracing_enabled = (await
await self.async_engine_client.is_tracing_enabled()) self.engine_client.is_tracing_enabled())
trace_headers = None trace_headers = None
if is_tracing_enabled: if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers) trace_headers = extract_trace_headers(raw_request.headers)
...@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers): raw_request.headers):
log_tracing_disabled_warning() log_tracing_disabled_warning()
generator = self.async_engine_client.generate( generator = self.engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params, sampling_params,
request_id_item, request_id_item,
...@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response # Streaming response
if stream: if stream:
return self.completion_stream_generator(request, return self.completion_stream_generator(
result_generator, request,
request_id, result_generator,
created_time, request_id,
model_name, created_time,
num_prompts=len(prompts), model_name,
tokenizer=tokenizer) num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
...@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time, created_time,
model_name, model_name,
tokenizer, tokenizer,
request_metadata,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
...@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
...@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset=False, exclude_none=True)) exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n" yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(str(e))
...@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse: ) -> CompletionResponse:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
...@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens,
) )
request_metadata.final_usage_info = usage
return CompletionResponse( return CompletionResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
......
...@@ -8,13 +8,13 @@ from fastapi import Request ...@@ -8,13 +8,13 @@ from fastapi import Request
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest, from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, EmbeddingResponseData,
ErrorResponse, UsageInfo) ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
...@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__( def __init__(
self, self,
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], base_model_paths: List[BaseModelPath],
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
): ):
super().__init__(async_engine_client=async_engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, base_model_paths=base_model_paths,
lora_modules=None, lora_modules=None,
prompt_adapters=None, prompt_adapters=None,
request_logger=request_logger) request_logger=request_logger)
...@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer( tokenizer = await self.engine_client.get_tokenizer(lora_request)
lora_request)
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
...@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported " "Prompt adapter is not supported "
"for embedding models") "for embedding models")
generator = self.async_engine_client.encode( generator = self.engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params, pooling_params,
request_id_item, request_id_item,
......
...@@ -8,7 +8,7 @@ from pydantic import Field ...@@ -8,7 +8,7 @@ from pydantic import Field
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter ...@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass @dataclass
class PromptAdapterPath: class PromptAdapterPath:
name: str name: str
...@@ -49,6 +55,7 @@ class PromptAdapterPath: ...@@ -49,6 +55,7 @@ class PromptAdapterPath:
class LoRAModulePath: class LoRAModulePath:
name: str name: str
path: str path: str
base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
...@@ -64,9 +71,9 @@ class OpenAIServing: ...@@ -64,9 +71,9 @@ class OpenAIServing:
def __init__( def __init__(
self, self,
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], base_model_paths: List[BaseModelPath],
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]], prompt_adapters: Optional[List[PromptAdapterPath]],
...@@ -75,21 +82,24 @@ class OpenAIServing: ...@@ -75,21 +82,24 @@ class OpenAIServing:
): ):
super().__init__() super().__init__()
self.async_engine_client = async_engine_client self.engine_client = engine_client
self.model_config = model_config self.model_config = model_config
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.served_model_names = served_model_names self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0) self.lora_id_counter = AtomicCounter(0)
self.lora_requests = [] self.lora_requests = []
if lora_modules is not None: if lora_modules is not None:
self.lora_requests = [ self.lora_requests = [
LoRARequest( LoRARequest(lora_name=lora.name,
lora_name=lora.name, lora_int_id=i,
lora_int_id=i, lora_path=lora.path,
lora_path=lora.path, base_model_name=lora.base_model_name
) for i, lora in enumerate(lora_modules, start=1) if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
] ]
self.prompt_adapter_requests = [] self.prompt_adapter_requests = []
...@@ -112,21 +122,23 @@ class OpenAIServing: ...@@ -112,21 +122,23 @@ class OpenAIServing:
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
ModelCard(id=served_model_name, ModelCard(id=base_model.name,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
root=self.served_model_names[0], root=base_model.model_path,
permission=[ModelPermission()]) permission=[ModelPermission()])
for served_model_name in self.served_model_names for base_model in self.base_model_paths
] ]
lora_cards = [ lora_cards = [
ModelCard(id=lora.lora_name, ModelCard(id=lora.lora_name,
root=self.served_model_names[0], root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()]) permission=[ModelPermission()])
for lora in self.lora_requests for lora in self.lora_requests
] ]
prompt_adapter_cards = [ prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name, ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.served_model_names[0], root=self.base_model_paths[0].name,
permission=[ModelPermission()]) permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests for prompt_adapter in self.prompt_adapter_requests
] ]
...@@ -159,7 +171,7 @@ class OpenAIServing: ...@@ -159,7 +171,7 @@ class OpenAIServing:
async def _guided_decode_logits_processor( async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest], self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.async_engine_client.get_decoding_config() decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor( return await get_guided_decoding_logits_processor(
...@@ -169,7 +181,7 @@ class OpenAIServing: ...@@ -169,7 +181,7 @@ class OpenAIServing:
self, self,
request: AnyRequest, request: AnyRequest,
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if self._is_model_supported(request.model):
return None return None
if request.model in [lora.lora_name for lora in self.lora_requests]: if request.model in [lora.lora_name for lora in self.lora_requests]:
return None return None
...@@ -187,7 +199,7 @@ class OpenAIServing: ...@@ -187,7 +199,7 @@ class OpenAIServing:
self, request: AnyRequest self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]: None, PromptAdapterRequest]]:
if request.model in self.served_model_names: if self._is_model_supported(request.model):
return None, None return None, None
for lora in self.lora_requests: for lora in self.lora_requests:
if request.model == lora.lora_name: if request.model == lora.lora_name:
...@@ -480,3 +492,6 @@ class OpenAIServing: ...@@ -480,3 +492,6 @@ class OpenAIServing:
if lora_request.lora_name != lora_name if lora_request.lora_name != lora_name
] ]
return f"Success: LoRA adapter '{lora_name}' removed successfully." return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
from typing import List, Optional, Union from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (apply_hf_chat_template, from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
load_chat_template, load_chat_template,
...@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest, ...@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
...@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__( def __init__(
self, self,
async_engine_client: AsyncEngineClient, engine_client: EngineClient,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], base_model_paths: List[BaseModelPath],
*, *,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
): ):
super().__init__(async_engine_client=async_engine_client, super().__init__(engine_client=engine_client,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, base_model_paths=base_model_paths,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=None, prompt_adapters=None,
request_logger=request_logger) request_logger=request_logger)
...@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]] prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest): if isinstance(request, TokenizeChatRequest):
...@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request, prompt_adapter_request,
) = self._maybe_get_adapters(request) ) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id, self._log_inputs(request_id,
request.tokens, request.tokens,
......
...@@ -59,10 +59,12 @@ if TYPE_CHECKING: ...@@ -59,10 +59,12 @@ if TYPE_CHECKING:
VERBOSE: bool = False VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")), ("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture # Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE": "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
...@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")), ("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture # Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool( lambda: bool(
...@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Time in ms for the zmq client to wait for a response from the backend # Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations # server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS": "VLLM_RPC_TIMEOUT":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
# a list of plugin names to load, separated by commas. # a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded # if this is not set, it means all plugins will be loaded
......
...@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase): ...@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
)) for rank in range(1, world_size) )) for rank in range(1, world_size)
] ]
self.worker_monitor = None
if world_size != 1 or is_async: if world_size != 1 or is_async:
if is_async: if is_async:
async_worker_list = self.workers + [self.driver_worker] async_worker_list = self.workers + [self.driver_worker]
......
import asyncio import asyncio
import os import os
import signal
import threading
import weakref
from functools import partial from functools import partial
from typing import Any, List, Optional from typing import Any, List, Optional
...@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly # Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well # sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
if executor := ref():
executor.shutdown()
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker( self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method) distributed_init_method=distributed_init_method)
self._run_workers("init_device") self._run_workers("init_device")
......
...@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread): ...@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)""" """Handle results from all workers (in background thread)"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(daemon=False) super().__init__(daemon=True)
# super().__init__(daemon=True)
self.result_queue = mp.Queue() self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
...@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread): ...@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'], def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler): result_handler: ResultHandler):
super().__init__(daemon=False) super().__init__(daemon=True)
# super().__init__(daemon=True)
self.workers = workers self.workers = workers
self.result_handler = result_handler self.result_handler = result_handler
self._close = False self._close = False
...@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread): ...@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread):
self._close = True self._close = True
# Kill / cleanup all workers # Kill / cleanup all workers
# for worker in self.workers: for worker in self.workers:
# process = worker.process process = worker.process
# if process.sentinel in dead_sentinels: if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S) process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0: if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s", logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode) process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers # Cleanup any remaining workers
# logger.info("Killing local vLLM worker processes") if logger:
logger.info("Killing local vLLM worker processes")
for worker in self.workers: for worker in self.workers:
worker.kill_worker() worker.kill_worker()
# Must be done after worker task queues are all closed # Must be done after worker task queues are all closed
...@@ -184,6 +168,8 @@ class ProcessWorkerWrapper: ...@@ -184,6 +168,8 @@ class ProcessWorkerWrapper:
self.tasks[task_id] = future self.tasks[task_id] = future
try: try:
self._task_queue.put((task_id, method, args, kwargs)) self._task_queue.put((task_id, method, args, kwargs))
except SystemExit:
raise
except BaseException as e: except BaseException as e:
del self.tasks[task_id] del self.tasks[task_id]
raise ChildProcessError("worker died") from e raise ChildProcessError("worker died") from e
...@@ -238,6 +224,10 @@ def _run_worker_process( ...@@ -238,6 +224,10 @@ def _run_worker_process(
try: try:
executor = getattr(worker, method) executor = getattr(worker, method)
output = executor(*args, **kwargs) output = executor(*args, **kwargs)
except SystemExit:
raise
except KeyboardInterrupt:
break
except BaseException as e: except BaseException as e:
tb = traceback.format_exc() tb = traceback.format_exc()
logger.error( logger.error(
...@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: ...@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file.start_new_line = False # type: ignore[attr-defined] file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined] file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign] file.write = write_with_prefix # type: ignore[method-assign]
\ No newline at end of file
...@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
required_version = version.parse("2.35") required_version = version.parse("2.35")
current_version = version.parse( current_version = version.parse(
pkg_resources.get_distribution("ray").version) pkg_resources.get_distribution("ray").version)
if current_version < required_version: # TODO: update the constraint once we adapt to the backward
raise ValueError(f"Ray version {required_version} or greater is " # incompatible API change from ray 2.36
if current_version != required_version:
raise ValueError(f"Ray version {required_version} is "
f"required, but found {current_version}") f"required, but found {current_version}")
import importlib.util import importlib.util
......
...@@ -26,6 +26,8 @@ logger = init_logger(__name__) ...@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class RayTPUExecutor(TPUExecutor): class RayTPUExecutor(TPUExecutor):
uses_ray: bool = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running # This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case. # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor): ...@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor):
) )
assert self.speculative_config is None assert self.speculative_config is None
worker_module_name = "vllm.worker.tpu_worker" if self.scheduler_config.is_multi_step:
worker_class_name = "TPUWorker" worker_module_name = "vllm.worker.multi_step_tpu_worker"
worker_class_name = "MultiStepTPUWorker"
else:
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"
# GKE does not fetch environment information from metadata server # GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we # and instead sets these from within the Ray process. Therefore we
......
...@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800 ...@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800
try: try:
import ray import ray
from ray._private.state import available_resources_per_node
from ray.util import placement_group_table from ray.util import placement_group_table
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
try:
from ray._private.state import available_resources_per_node
except ImportError:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from ray._private.state import state as _state
available_resources_per_node = _state._available_resources_per_node
class RayWorkerWrapper(WorkerWrapperBase): class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
......
...@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase): ...@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
rank: int = 0, rank: int = 0,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
): ):
from vllm.worker.tpu_worker import TPUWorker if self.scheduler_config.is_multi_step:
from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker
worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank, worker = MultiStepTPUWorker(**self._get_worker_kwargs(
distributed_init_method)) local_rank, rank, distributed_init_method))
return worker return worker
else:
from vllm.worker.tpu_worker import TPUWorker
worker = TPUWorker(**self._get_worker_kwargs(
local_rank, rank, distributed_init_method))
return worker
def initialize_cache( def initialize_cache(
self, self,
......
...@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs): ...@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available. available.
""" """
encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1 = TypeVar("_T1", _T1 = TypeVar("_T1",
bound=SingletonPromptInputs, bound=SingletonPromptInputs,
......
...@@ -128,6 +128,7 @@ class InputPreprocessor: ...@@ -128,6 +128,7 @@ class InputPreprocessor:
def _prepare_decoder_input_ids_for_generation( def _prepare_decoder_input_ids_for_generation(
self, self,
decoder_input_ids: Optional[List[int]], decoder_input_ids: Optional[List[int]],
force_bos: bool = True,
) -> List[int]: ) -> List[int]:
""" """
Prepares `decoder_input_ids` for generation with encoder-decoder models. Prepares `decoder_input_ids` for generation with encoder-decoder models.
...@@ -157,8 +158,8 @@ class InputPreprocessor: ...@@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids # use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt() decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0 if force_bos and (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id): or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids return decoder_input_ids
...@@ -295,18 +296,25 @@ class InputPreprocessor: ...@@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None: if decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are " raise ValueError(
"not supported yet") "Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = ( # For Multi-Modal models (e.g., mllama), the text input can be
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) # <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs( return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids, prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt, prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt, encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
) )
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
......
import functools import functools
from array import array
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
...@@ -10,6 +9,7 @@ from transformers import PretrainedConfig ...@@ -10,6 +9,7 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides
from .data import LLMInputs from .data import LLMInputs
...@@ -22,10 +22,6 @@ logger = init_logger(__name__) ...@@ -22,10 +22,6 @@ logger = init_logger(__name__)
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@dataclass(frozen=True) @dataclass(frozen=True)
class InputContext: class InputContext:
...@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol): ...@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx: InputContext, ctx: InputContext,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
**mm_processor_kwargs: Any,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
""" """
Create dummy data to be inputted into the model. Create dummy data to be inputted into the model.
Note: Note:
:data:`InputProcessor` is not applied to the dummy data. :data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
""" """
... ...
...@@ -111,6 +112,8 @@ class InputRegistry: ...@@ -111,6 +112,8 @@ class InputRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module], self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {} DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module], self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {} InputProcessor] = {}
...@@ -130,8 +133,7 @@ class InputRegistry: ...@@ -130,8 +133,7 @@ class InputRegistry:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
dummy_seq_data = SequenceData( dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
dummy_multi_modal_data = None dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data return dummy_seq_data, dummy_multi_modal_data
...@@ -158,11 +160,48 @@ class InputRegistry: ...@@ -158,11 +160,48 @@ class InputRegistry:
return wrapper return wrapper
def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_encoder_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
if model_cls in self._dummy_encoder_factories_by_model_type:
dummy_factory = self._dummy_encoder_factories_by_model_type[
model_cls]
else:
logger.warning(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead.",
model_cls)
dummy_factory = self._get_dummy_data_factory(model_cls)
return dummy_factory
def dummy_data_for_profiling( def dummy_data_for_profiling(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
seq_len: int, seq_len: int,
mm_registry: "MultiModalRegistry", mm_registry: "MultiModalRegistry",
is_encoder_data: bool = False,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
...@@ -180,22 +219,29 @@ class InputRegistry: ...@@ -180,22 +219,29 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \ if is_encoder_data:
.get(model_cls, self._default_dummy_data_factory) dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
seq_data, mm_data = dummy_factory( seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len,
InputContext(model_config), _MultiModalCounts(mm_counts),
seq_len, **mm_processor_kwargs)
_MultiModalCounts(mm_counts),
)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, ( if len(num_tokens) < seq_len:
f"Expected at least {seq_len} dummy tokens for profiling, " if is_encoder_data:
f"but found {len(num_tokens)} tokens instead.") logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None: if mm_data is not None:
for k, v in mm_data.items(): for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1 num_items = len(v) if isinstance(v, list) else 1
...@@ -235,6 +281,10 @@ class InputRegistry: ...@@ -235,6 +281,10 @@ class InputRegistry:
return wrapper return wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]):
return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig", def process_input(self, model_config: "ModelConfig",
inputs: LLMInputs) -> LLMInputs: inputs: LLMInputs) -> LLMInputs:
""" """
...@@ -249,15 +299,17 @@ class InputRegistry: ...@@ -249,15 +299,17 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
processor = self._get_model_input_processor(model_cls)
processor = self._input_processors_by_model_type \ mm_processor_kwargs = get_allowed_kwarg_only_overrides(
.get(model_cls, self._default_input_processor) processor, overrides=model_config.mm_processor_kwargs)
return processor(InputContext(model_config), inputs) return processor(InputContext(model_config), inputs,
**mm_processor_kwargs)
def create_input_processor(self, model_config: "ModelConfig"): def create_input_processor(self, model_config: "ModelConfig"):
""" """
Create an input processor (see :meth:`process_input`) for a Create an input processor (see :meth:`_process_input`) for a
specific model. specific model.
""" """
return functools.partial(self.process_input, model_config) return functools.partial(self.process_input, model_config)
...@@ -100,7 +100,7 @@ def _bgmv_expand( ...@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be corresponding to each batch, An index of -1 means no lora should be
applied. applied.
batches (int): batch size batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output. results to the output.
""" """
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
......
...@@ -104,7 +104,7 @@ def _bgmv_expand_slice( ...@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be corresponding to each batch, An index of -1 means no lora should be
applied. applied.
slice_offst (int): output_tensor's offst slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size slice_size (int): current output_tensor's size
batches (int): batch size batches (int): batch size
add_inputs (bool, optional): Defaults to False. add_inputs (bool, optional): Defaults to False.
......
...@@ -106,6 +106,7 @@ def _sgmv_expand( ...@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
""" """
...@@ -115,17 +116,19 @@ def _sgmv_expand( ...@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10]. [0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be corresponding to each batch. An index of -1 means no lora should be
applied. applied.
batches (int): batch size batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences max_seq_length (int): The max sequence lengths of the sequences in the
in the batch batch.
add_inputs (bool, optional): Defaults to False. adds the final lora token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output. results to the output.
""" """
...@@ -134,6 +137,7 @@ def _sgmv_expand( ...@@ -134,6 +137,7 @@ def _sgmv_expand(
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1) assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches assert lora_indices_tensor.size(0) == batches
......
...@@ -112,6 +112,7 @@ def _sgmv_expand_slice( ...@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int,
slice_offset: int, slice_offset: int,
slice_size: int, slice_size: int,
add_inputs: bool = False, add_inputs: bool = False,
...@@ -124,20 +125,22 @@ def _sgmv_expand_slice( ...@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10]. [0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be corresponding to each batch. An index of -1 means no lora should be
applied. applied.
batches (int): batch size batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences max_seq_length (int): The max sequence lengths of the sequences
in the batch in the batch
slice_offst (int): output_tensor's offst token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.. results to the output.
""" """
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
...@@ -145,6 +148,7 @@ def _sgmv_expand_slice( ...@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1) assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches assert lora_indices_tensor.size(0) == batches
......
...@@ -110,6 +110,7 @@ def _sgmv_shrink( ...@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int,
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
...@@ -120,17 +121,19 @@ def _sgmv_shrink( ...@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4]. [0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be corresponding to each batch. An index of -1 means no lora should be
applied. applied.
batches (int): batch size batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences max_seq_length (int): The max sequence lengths of the sequences in the
in the batch batch.
scaling (float): Scaling factor. token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
""" """
assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16] assert inputs.dtype in [torch.float16, torch.bfloat16]
...@@ -138,6 +141,7 @@ def _sgmv_shrink( ...@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
] ]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1) assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches assert lora_indices_tensor.size(0) == batches
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment