"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "2cd4d58df4cb4187e36ba9bdabc2819e6f579848"
Unverified commit b96f7314, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Pass Renderer to Input Processor (#34329)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ced2a92f
...@@ -125,6 +125,7 @@ class TestInitializeToolSessions: ...@@ -125,6 +125,7 @@ class TestInitializeToolSessions:
engine_client = MagicMock() engine_client = MagicMock()
model_config = MagicMock() model_config = MagicMock()
model_config.max_model_len = 100
model_config.hf_config.model_type = "test" model_config.hf_config.model_type = "test"
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config engine_client.model_config = model_config
...@@ -212,6 +213,7 @@ class TestValidateGeneratorInput: ...@@ -212,6 +213,7 @@ class TestValidateGeneratorInput:
engine_client = MagicMock() engine_client = MagicMock()
model_config = MagicMock() model_config = MagicMock()
model_config.max_model_len = 100
model_config.hf_config.model_type = "test" model_config.hf_config.model_type = "test"
model_config.get_diff_sampling_param.return_value = {} model_config.get_diff_sampling_param.return_value = {}
engine_client.model_config = model_config engine_client.model_config = model_config
...@@ -231,9 +233,6 @@ class TestValidateGeneratorInput: ...@@ -231,9 +233,6 @@ class TestValidateGeneratorInput:
chat_template_content_format="auto", chat_template_content_format="auto",
) )
# Set max_model_len for testing
instance.max_model_len = 100
return instance return instance
def test_validate_generator_input(self, serving_responses_instance): def test_validate_generator_input(self, serving_responses_instance):
......
...@@ -507,7 +507,8 @@ def test_apc_single_prompt_block_align_alignment( ...@@ -507,7 +507,8 @@ def test_apc_single_prompt_block_align_alignment(
vllm_runner_kwargs["enable_prefix_caching"] = True vllm_runner_kwargs["enable_prefix_caching"] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model: with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size # Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size vllm_config = vllm_model.llm.llm_engine.vllm_config
mamba_block_size = vllm_config.cache_config.mamba_block_size
# In case the hybrid model does not have the # In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant # "mamba_block_size" assume a fixed constant
...@@ -660,7 +661,8 @@ def test_apc_multiple_prompts_block_align_alignment( ...@@ -660,7 +661,8 @@ def test_apc_multiple_prompts_block_align_alignment(
vllm_runner_kwargs["enable_prefix_caching"] = True vllm_runner_kwargs["enable_prefix_caching"] = True
with vllm_runner(**vllm_runner_kwargs) as vllm_model: with vllm_runner(**vllm_runner_kwargs) as vllm_model:
# Retrieve the default mamba state block size # Retrieve the default mamba state block size
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size vllm_config = vllm_model.llm.llm_engine.vllm_config
mamba_block_size = vllm_config.cache_config.mamba_block_size
# In case the hybrid model does not have the # In case the hybrid model does not have the
# "mamba_block_size" assume a fixed constant # "mamba_block_size" assume a fixed constant
......
...@@ -25,7 +25,8 @@ def test_classify_models( ...@@ -25,7 +25,8 @@ def test_classify_models(
with vllm_runner( with vllm_runner(
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
) as vllm_model: ) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config vllm_config = vllm_model.llm.llm_engine.vllm_config
cache_config = vllm_config.cache_config
assert cache_config.enable_prefix_caching assert cache_config.enable_prefix_caching
# First Run # First Run
...@@ -74,7 +75,8 @@ def test_embed_models( ...@@ -74,7 +75,8 @@ def test_embed_models(
max_model_len=None, max_model_len=None,
enable_prefix_caching=True, enable_prefix_caching=True,
) as vllm_model: ) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config vllm_config = vllm_model.llm.llm_engine.vllm_config
cache_config = vllm_config.cache_config
assert cache_config.enable_prefix_caching assert cache_config.enable_prefix_caching
# First Run # First Run
...@@ -106,5 +108,6 @@ def test_non_causal_models( ...@@ -106,5 +108,6 @@ def test_non_causal_models(
hf_runner, vllm_runner, example_prompts, model: str, dtype: str hf_runner, vllm_runner, example_prompts, model: str, dtype: str
) -> None: ) -> None:
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config vllm_config = vllm_model.llm.llm_engine.vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching assert not cache_config.enable_prefix_caching
...@@ -161,7 +161,8 @@ def test_pooling_prefix_cache(vllm_runner, monkeypatch): ...@@ -161,7 +161,8 @@ def test_pooling_prefix_cache(vllm_runner, monkeypatch):
assert chunks[0] <= prompt1_len assert chunks[0] <= prompt1_len
assert chunks[0] < prompt2_len assert chunks[0] < prompt2_len
cache_config = llm.get_llm().llm_engine.cache_config vllm_config = llm.get_llm().llm_engine.vllm_config
cache_config = vllm_config.cache_config
print(f"{cache_config=}") print(f"{cache_config=}")
# Prefixes are cached in blocks # Prefixes are cached in blocks
assert (prompt2_len - chunks[0]) % cache_config.block_size == 0 assert (prompt2_len - chunks[0]) % cache_config.block_size == 0
...@@ -311,7 +311,8 @@ def test_get_logprobs_and_prompt_logprobs( ...@@ -311,7 +311,8 @@ def test_get_logprobs_and_prompt_logprobs(
temperature: "temperature" sampling parameter temperature: "temperature" sampling parameter
example_prompts: example prompt fixture example_prompts: example prompt fixture
""" """
do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching vllm_config = vllm_model.llm.llm_engine.vllm_config
do_apc = vllm_config.cache_config.enable_prefix_caching
if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT): if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
# Skip some test-cases to save time. # Skip some test-cases to save time.
pytest.skip() pytest.skip()
......
...@@ -54,7 +54,7 @@ class PoolerConfig: ...@@ -54,7 +54,7 @@ class PoolerConfig:
Reduce the dimensions of embeddings if model Reduce the dimensions of embeddings if model
support matryoshka representation. Defaults to None. support matryoshka representation. Defaults to None.
""" """
enable_chunked_processing: bool | None = None enable_chunked_processing: bool = False
""" """
Whether to enable chunked processing for long inputs that exceed the model's Whether to enable chunked processing for long inputs that exceed the model's
maximum position embeddings. When enabled, long inputs will be split into maximum position embeddings. When enabled, long inputs will be split into
......
...@@ -31,12 +31,9 @@ class EngineClient(ABC): ...@@ -31,12 +31,9 @@ class EngineClient(ABC):
vllm_config: VllmConfig vllm_config: VllmConfig
model_config: ModelConfig model_config: ModelConfig
input_processor: InputProcessor renderer: BaseRenderer
io_processor: IOProcessor | None io_processor: IOProcessor | None
input_processor: InputProcessor
@property
@abstractmethod
def renderer(self) -> BaseRenderer: ...
@property @property
@abstractmethod @abstractmethod
......
...@@ -356,8 +356,9 @@ class LLM: ...@@ -356,8 +356,9 @@ class LLM:
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
self.model_config = self.llm_engine.model_config self.model_config = self.llm_engine.model_config
self.input_processor = self.llm_engine.input_processor self.renderer = self.llm_engine.renderer
self.io_processor = self.llm_engine.io_processor self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
# Cache for __repr__ to avoid repeated collective_rpc calls # Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None self._cached_repr: str | None = None
...@@ -816,7 +817,7 @@ class LLM: ...@@ -816,7 +817,7 @@ class LLM:
A list of `TokensPrompts` objects containing the tokenized prompt A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs. after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.llm_engine.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
parsed_prompts = [ parsed_prompts = [
...@@ -858,7 +859,7 @@ class LLM: ...@@ -858,7 +859,7 @@ class LLM:
A list of `TokensPrompts` objects containing the tokenized prompt A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs. after chat template interpolation, and the raw multi-modal inputs.
""" """
renderer = self.llm_engine.renderer renderer = self.renderer
chat_params = ChatParams( chat_params = ChatParams(
chat_template=chat_template, chat_template=chat_template,
......
...@@ -239,8 +239,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -239,8 +239,7 @@ class OpenAIServingChat(OpenAIServing):
raise self.engine_client.dead_error raise self.engine_client.dead_error
try: try:
renderer = self.engine_client.renderer tokenizer = self.renderer.tokenizer
tokenizer = renderer.tokenizer
tool_parser = self.tool_parser tool_parser = self.tool_parser
...@@ -375,6 +374,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -375,6 +374,7 @@ class OpenAIServingChat(OpenAIServing):
data_parallel_rank = self._get_data_parallel_rank(raw_request) data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
...@@ -387,7 +387,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -387,7 +387,7 @@ class OpenAIServingChat(OpenAIServing):
) )
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.max_model_len, max_model_len,
request.max_completion_tokens request.max_completion_tokens
if request.max_completion_tokens is not None if request.max_completion_tokens is not None
else request.max_tokens, else request.max_tokens,
......
...@@ -157,13 +157,14 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -157,13 +157,14 @@ class OpenAIServingCompletion(OpenAIServing):
data_parallel_rank = self._get_data_parallel_rank(raw_request) data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text = self._extract_prompt_text(engine_prompt) prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.max_model_len, max_model_len,
request.max_tokens, request.max_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
......
...@@ -242,11 +242,10 @@ class OpenAIServing: ...@@ -242,11 +242,10 @@ class OpenAIServing:
self.log_error_stack = log_error_stack self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor self.model_config = engine_client.model_config
self.io_processor = self.models.io_processor self.renderer = engine_client.renderer
self.renderer = self.models.renderer self.io_processor = engine_client.io_processor
self.model_config = self.models.model_config self.input_processor = engine_client.input_processor
self.max_model_len = self.model_config.max_model_len
async def beam_search( async def beam_search(
self, self,
...@@ -537,7 +536,7 @@ class OpenAIServing: ...@@ -537,7 +536,7 @@ class OpenAIServing:
if ( if (
truncate_prompt_tokens is not None truncate_prompt_tokens is not None
and truncate_prompt_tokens > self.max_model_len and truncate_prompt_tokens > self.model_config.max_model_len
): ):
return self.create_error_response( return self.create_error_response(
"truncate_prompt_tokens value is " "truncate_prompt_tokens value is "
...@@ -844,6 +843,7 @@ class OpenAIServing: ...@@ -844,6 +843,7 @@ class OpenAIServing:
input_text: str, input_text: str,
) -> TokensPrompt: ) -> TokensPrompt:
token_num = len(input_ids) token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
# Note: EmbeddingRequest, ClassificationRequest, # Note: EmbeddingRequest, ClassificationRequest,
# and ScoreRequest doesn't have max_tokens # and ScoreRequest doesn't have max_tokens
...@@ -862,7 +862,7 @@ class OpenAIServing: ...@@ -862,7 +862,7 @@ class OpenAIServing:
): ):
# Note: input length can be up to the entire model context length # Note: input length can be up to the entire model context length
# since these requests don't generate tokens. # since these requests don't generate tokens.
if token_num > self.max_model_len: if token_num > max_model_len:
operations: dict[type[AnyRequest], str] = { operations: dict[type[AnyRequest], str] = {
ScoreDataRequest: "score", ScoreDataRequest: "score",
ScoreTextRequest: "score", ScoreTextRequest: "score",
...@@ -873,7 +873,7 @@ class OpenAIServing: ...@@ -873,7 +873,7 @@ class OpenAIServing:
operation = operations.get(type(request), "embedding generation") operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError( raise VLLMValidationError(
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested " f"{max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. " f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input.", f"Please reduce the length of the input.",
parameter="input_tokens", parameter="input_tokens",
...@@ -898,22 +898,22 @@ class OpenAIServing: ...@@ -898,22 +898,22 @@ class OpenAIServing:
# Note: input length can be up to model context length - 1 for # Note: input length can be up to model context length - 1 for
# completion-like requests. # completion-like requests.
if token_num >= self.max_model_len: if token_num >= max_model_len:
raise VLLMValidationError( raise VLLMValidationError(
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, your request has " f"{max_model_len} tokens. However, your request has "
f"{token_num} input tokens. Please reduce the length of " f"{token_num} input tokens. Please reduce the length of "
"the input messages.", "the input messages.",
parameter="input_tokens", parameter="input_tokens",
value=token_num, value=token_num,
) )
if max_tokens is not None and token_num + max_tokens > self.max_model_len: if max_tokens is not None and token_num + max_tokens > max_model_len:
raise VLLMValidationError( raise VLLMValidationError(
"'max_tokens' or 'max_completion_tokens' is too large: " "'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is " f"{max_tokens}. This model's maximum context length is "
f"{self.max_model_len} tokens and your request has " f"{max_model_len} tokens and your request has "
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}" f"{token_num} input tokens ({max_tokens} > {max_model_len}"
f" - {token_num}).", f" - {token_num}).",
parameter="max_tokens", parameter="max_tokens",
value=max_tokens, value=max_tokens,
...@@ -1089,6 +1089,7 @@ class OpenAIServing: ...@@ -1089,6 +1089,7 @@ class OpenAIServing:
priority: int = 0, priority: int = 0,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
): ):
max_model_len = self.model_config.max_model_len
prompt_text = self._extract_prompt_text(engine_prompt) prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority orig_priority = priority
...@@ -1148,7 +1149,7 @@ class OpenAIServing: ...@@ -1148,7 +1149,7 @@ class OpenAIServing:
token_ids = context.render_for_completion() token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids) engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids) sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
...@@ -1162,7 +1163,7 @@ class OpenAIServing: ...@@ -1162,7 +1163,7 @@ class OpenAIServing:
prompt_text = self._extract_prompt_text(engine_prompt) prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
self.max_model_len, max_model_len,
context.request.max_output_tokens, context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore self.default_sampling_params, # type: ignore
......
...@@ -59,11 +59,10 @@ class OpenAIServingModels: ...@@ -59,11 +59,10 @@ class OpenAIServingModels:
) )
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
self.input_processor = self.engine_client.input_processor
self.io_processor = self.engine_client.io_processor
self.renderer = self.engine_client.renderer
self.model_config = self.engine_client.model_config self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len self.renderer = self.engine_client.renderer
self.io_processor = self.engine_client.io_processor
self.input_processor = self.engine_client.input_processor
async def init_static_loras(self): async def init_static_loras(self):
"""Loads all static LoRA modules. """Loads all static LoRA modules.
...@@ -96,12 +95,13 @@ class OpenAIServingModels: ...@@ -96,12 +95,13 @@ class OpenAIServingModels:
return self.base_model_paths[0].name return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all """Show available models. This includes the base model and all adapters."""
adapters""" max_model_len = self.model_config.max_model_len
model_cards = [ model_cards = [
ModelCard( ModelCard(
id=base_model.name, id=base_model.name,
max_model_len=self.max_model_len, max_model_len=max_model_len,
root=base_model.model_path, root=base_model.model_path,
permission=[ModelPermission()], permission=[ModelPermission()],
) )
......
...@@ -296,10 +296,12 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -296,10 +296,12 @@ class OpenAIServingResponses(OpenAIServing):
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt) prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len: max_model_len = self.model_config.max_model_len
if prompt_len >= max_model_len:
error_message = ( error_message = (
f"The engine prompt length {prompt_len} " f"The engine prompt length {prompt_len} "
f"exceeds the max_model_len {self.max_model_len}. " f"exceeds the max_model_len {max_model_len}. "
"Please reduce prompt." "Please reduce prompt."
) )
return self.create_error_response( return self.create_error_response(
...@@ -414,6 +416,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -414,6 +416,7 @@ class OpenAIServingResponses(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[ConversationContext, None]] = [] generators: list[AsyncGenerator[ConversationContext, None]] = []
builtin_tool_list: list[str] = [] builtin_tool_list: list[str] = []
...@@ -431,8 +434,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -431,8 +434,7 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
available_tools = [] available_tools = []
try: try:
renderer = self.engine_client.renderer tokenizer = self.renderer.get_tokenizer()
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts: for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt) maybe_error = self._validate_generator_input(engine_prompt)
...@@ -440,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -440,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
return maybe_error return maybe_error
default_max_tokens = get_max_tokens( default_max_tokens = get_max_tokens(
self.max_model_len, max_model_len,
request.max_output_tokens, request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
......
...@@ -69,16 +69,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -69,16 +69,8 @@ class OpenAIServingEmbedding(OpenAIServing):
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
pooler_config = self.model_config.pooler_config pooler_config = self.model_config.pooler_config
assert pooler_config is not None
# Avoid repeated attribute lookups self.pooler_config = pooler_config
self.supports_chunked_processing = bool(
pooler_config and pooler_config.enable_chunked_processing
)
self.max_embed_len = (
pooler_config.max_embed_len
if pooler_config and pooler_config.max_embed_len
else None
)
async def _preprocess( async def _preprocess(
self, self,
...@@ -240,7 +232,7 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -240,7 +232,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"""Check if chunked processing should be used for this request.""" """Check if chunked processing should be used for this request."""
return ( return (
isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)) isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
and self.supports_chunked_processing and self.pooler_config.enable_chunked_processing
) )
async def _process_chunked_request( async def _process_chunked_request(
...@@ -310,14 +302,14 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -310,14 +302,14 @@ class OpenAIServingEmbedding(OpenAIServing):
max_pos_embeddings = self._get_max_position_embeddings() max_pos_embeddings = self._get_max_position_embeddings()
# Determine the effective max length for validation # Determine the effective max length for validation
if self.max_embed_len is not None: if self.pooler_config.max_embed_len:
# Use max_embed_len for validation instead of max_model_len # Use max_embed_len for validation instead of max_model_len
length_type = "maximum embedding input length" length_type = "maximum embedding input length"
max_length_value = self.max_embed_len max_length_value = self.pooler_config.max_embed_len
else: else:
# Fall back to max_model_len validation (original behavior) # Fall back to max_model_len validation (original behavior)
length_type = "maximum context length" length_type = "maximum context length"
max_length_value = self.max_model_len max_length_value = self.model_config.max_model_len
validation_error_msg = ( validation_error_msg = (
"This model's {length_type} is {max_length_value} tokens. " "This model's {length_type} is {max_length_value} tokens. "
......
...@@ -117,7 +117,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -117,7 +117,7 @@ class OpenAIServingTokenization(OpenAIServing):
tokens=input_ids, tokens=input_ids,
token_strs=token_strs, token_strs=token_strs,
count=len(input_ids), count=len(input_ids),
max_model_len=self.max_model_len, max_model_len=self.model_config.max_model_len,
) )
async def create_detokenize( async def create_detokenize(
......
...@@ -16,7 +16,7 @@ from vllm.multimodal.inputs import ( ...@@ -16,7 +16,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import ( from vllm.renderers.inputs import (
DecoderDictPrompt, DecoderDictPrompt,
DecoderOnlyDictPrompt, DecoderOnlyDictPrompt,
...@@ -56,6 +56,7 @@ class InputPreprocessor: ...@@ -56,6 +56,7 @@ class InputPreprocessor:
self, self,
model_config: ModelConfig, model_config: ModelConfig,
observability_config: ObservabilityConfig | None = None, observability_config: ObservabilityConfig | None = None,
renderer: BaseRenderer | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None, mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
...@@ -63,7 +64,7 @@ class InputPreprocessor: ...@@ -63,7 +64,7 @@ class InputPreprocessor:
self.model_config = model_config self.model_config = model_config
self.observability_config = observability_config self.observability_config = observability_config
self.renderer = renderer_from_config(model_config) self.renderer = renderer or renderer_from_config(model_config)
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
......
...@@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -27,7 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import merge_kwargs, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
...@@ -110,9 +110,10 @@ class AsyncLLM(EngineClient): ...@@ -110,9 +110,10 @@ class AsyncLLM(EngineClient):
# Ensure we can serialize custom transformer configs # Ensure we can serialize custom transformer configs
maybe_register_config_serialize_by_value() maybe_register_config_serialize_by_value()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
tracing_endpoint = self.observability_config.otlp_traces_endpoint tracing_endpoint = self.observability_config.otlp_traces_endpoint
if tracing_endpoint is not None: if tracing_endpoint is not None:
init_tracer("vllm.llm_engine", tracing_endpoint) init_tracer("vllm.llm_engine", tracing_endpoint)
...@@ -131,20 +132,22 @@ class AsyncLLM(EngineClient): ...@@ -131,20 +132,22 @@ class AsyncLLM(EngineClient):
"enabling logging without default stat loggers." "enabling logging without default stat loggers."
) )
self.input_processor = InputProcessor(self.vllm_config) self.renderer = renderer = renderer_from_config(self.model_config)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
self.vllm_config, self.vllm_config,
self.model_config.io_processor_plugin, self.model_config.io_processor_plugin,
) )
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput). # Convert TokPrompt --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput.
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, renderer.tokenizer,
log_stats=self.log_stats, log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval, stream_interval=self.vllm_config.scheduler_config.stream_interval,
tracing_enabled=tracing_endpoint is not None,
) )
if tracing_endpoint is not None:
self.output_processor.tracing_enabled = True
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_async_mp_client( self.engine_core = EngineCoreClient.make_async_mp_client(
...@@ -891,17 +894,13 @@ class AsyncLLM(EngineClient): ...@@ -891,17 +894,13 @@ class AsyncLLM(EngineClient):
@property @property
def tokenizer(self) -> TokenizerLike | None: def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer return self.renderer.tokenizer
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.input_processor.get_tokenizer() return self.renderer.get_tokenizer()
@property
def renderer(self) -> BaseRenderer:
return self.input_processor.renderer
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None # type: ignore return self.observability_config.otlp_traces_endpoint is not None
async def do_log_stats(self) -> None: async def do_log_stats(self) -> None:
if self.logger_manager: if self.logger_manager:
......
...@@ -27,7 +27,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems ...@@ -27,7 +27,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
...@@ -44,6 +44,8 @@ class InputProcessor: ...@@ -44,6 +44,8 @@ class InputProcessor:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
renderer: BaseRenderer | None = None,
*,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -57,6 +59,7 @@ class InputProcessor: ...@@ -57,6 +59,7 @@ class InputProcessor:
self.generation_config_fields = model_config.try_get_generation_config() self.generation_config_fields = model_config.try_get_generation_config()
self.renderer = renderer or renderer_from_config(model_config)
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config) self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
...@@ -74,20 +77,17 @@ class InputProcessor: ...@@ -74,20 +77,17 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
model_config, model_config,
self.observability_config, self.observability_config,
mm_registry, renderer=renderer,
mm_registry=mm_registry,
mm_processor_cache=self.mm_processor_cache, mm_processor_cache=self.mm_processor_cache,
) )
@property @property
def tokenizer(self) -> TokenizerLike | None: def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer return self.renderer.tokenizer
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.input_preprocessor.get_tokenizer() return self.renderer.get_tokenizer()
@property
def renderer(self) -> BaseRenderer:
return self.input_preprocessor.renderer
def _validate_params( def _validate_params(
self, self,
......
...@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -62,9 +62,12 @@ class LLMEngine: ...@@ -62,9 +62,12 @@ class LLMEngine:
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
) -> None: ) -> None:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.observability_config = vllm_config.observability_config
tracing_endpoint = self.observability_config.otlp_traces_endpoint
if tracing_endpoint is not None:
init_tracer("vllm.llm_engine", tracing_endpoint)
self.log_stats = log_stats self.log_stats = log_stats
...@@ -87,22 +90,22 @@ class LLMEngine: ...@@ -87,22 +90,22 @@ class LLMEngine:
self.dp_group = None self.dp_group = None
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
self.input_processor = InputProcessor(self.vllm_config) self.renderer = renderer = renderer_from_config(self.model_config)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
self.vllm_config, self.vllm_config,
self.model_config.io_processor_plugin, self.model_config.io_processor_plugin,
) )
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput). # Convert TokPrompt --> EngineCoreRequest.
self.input_processor = InputProcessor(self.vllm_config, renderer)
# Converts EngineCoreOutputs --> RequestOutput.
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, renderer.tokenizer,
log_stats=self.log_stats, log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval, stream_interval=self.vllm_config.scheduler_config.stream_interval,
tracing_enabled=tracing_endpoint is not None,
) )
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
init_tracer("vllm.llm_engine", endpoint)
self.output_processor.tracing_enabled = True
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client( self.engine_core = EngineCoreClient.make_client(
...@@ -365,14 +368,10 @@ class LLMEngine: ...@@ -365,14 +368,10 @@ class LLMEngine:
@property @property
def tokenizer(self) -> TokenizerLike | None: def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer return self.renderer.tokenizer
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.input_processor.get_tokenizer() return self.renderer.get_tokenizer()
@property
def renderer(self) -> BaseRenderer:
return self.input_processor.renderer
def do_log_stats(self) -> None: def do_log_stats(self) -> None:
"""Log stats if logging is enabled.""" """Log stats if logging is enabled."""
......
...@@ -417,8 +417,10 @@ class OutputProcessor: ...@@ -417,8 +417,10 @@ class OutputProcessor:
def __init__( def __init__(
self, self,
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
*,
log_stats: bool, log_stats: bool,
stream_interval: int = 1, stream_interval: int = 1,
tracing_enabled: bool = False,
): ):
self.log_stats = log_stats self.log_stats = log_stats
self.tokenizer = tokenizer self.tokenizer = tokenizer
...@@ -427,7 +429,7 @@ class OutputProcessor: ...@@ -427,7 +429,7 @@ class OutputProcessor:
self.parent_requests: dict[str, ParentRequest] = {} self.parent_requests: dict[str, ParentRequest] = {}
self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list) self.external_req_ids: defaultdict[str, list[str]] = defaultdict(list)
self.lora_states = LoRARequestStates(log_stats) self.lora_states = LoRARequestStates(log_stats)
self.tracing_enabled: bool = False self.tracing_enabled = tracing_enabled
self._requests_drained = asyncio.Event() self._requests_drained = asyncio.Event()
self._requests_drained.set() self._requests_drained.set()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment