Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da6ea29f
Unverified
Commit
da6ea29f
authored
Mar 20, 2025
by
Nick Hill
Committed by
GitHub
Mar 20, 2025
Browse files
[V1] Avoid redundant input processing in n>1 case (#14985)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
7297941b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
82 additions
and
142 deletions
+82
-142
tests/lora/test_tokenizer_group.py
tests/lora/test_tokenizer_group.py
+2
-4
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+9
-18
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+0
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+0
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+1
-4
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+9
-54
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+0
-2
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+2
-8
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+0
-2
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+32
-19
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+26
-16
vllm/v1/engine/parallel_sampling.py
vllm/v1/engine/parallel_sampling.py
+1
-12
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+0
-1
No files found.
tests/lora/test_tokenizer_group.py
View file @
da6ea29f
...
...
@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
)
lora_request
=
LoRARequest
(
"1"
,
1
,
sql_lora_files
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer_group
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
...
...
tests/tokenization/test_tokenizer_group.py
View file @
da6ea29f
...
...
@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
max_input_length
=
None
,
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer_group
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
prompt
=
"prompt"
,
lora_request
=
None
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
...
...
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
# and check that all requests are processed correctly.
num_requests
=
tokenizer_group_pool
.
pool_size
*
5
requests
=
[
tokenizer_group_pool
.
encode_async
(
request_id
=
str
(
i
),
prompt
=
f
"prompt
{
i
}
"
,
tokenizer_group_pool
.
encode_async
(
prompt
=
f
"prompt
{
i
}
"
,
lora_request
=
None
)
for
i
in
range
(
num_requests
)
]
...
...
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
fail_at
[
0
]
=
1000
# We should recover successfully.
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# Check that we have a new actor
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
...
...
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# We should fail after re-initialization.
with
pytest
.
raises
(
RuntimeError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# check_health should raise the same thing
...
...
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# Prompt too long error
with
pytest
.
raises
(
ValueError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
*
100
,
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
*
100
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
prompt
=
"prompt"
,
lora_request
=
None
)
# Actors should stay the same.
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
vllm/engine/async_llm_engine.py
View file @
da6ea29f
...
...
@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine):
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
...
...
vllm/engine/llm_engine.py
View file @
da6ea29f
...
...
@@ -783,7 +783,6 @@ class LLMEngine:
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
...
...
vllm/engine/protocol.py
View file @
da6ea29f
...
...
@@ -81,10 +81,7 @@ class EngineClient(ABC):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
NotImplementedError
else
:
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
processed_inputs
=
preprocessor
.
_prompt_to_llm_inputs
(
prompt
)
prompt_token_ids
=
processed_inputs
[
"prompt_token_ids"
]
prompt_text
=
processed_inputs
.
get
(
"prompt"
)
...
...
vllm/inputs/preprocess.py
View file @
da6ea29f
...
...
@@ -182,7 +182,6 @@ class InputPreprocessor:
def
_tokenize_prompt
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
list
[
int
]:
"""
...
...
@@ -202,15 +201,13 @@ class InputPreprocessor:
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
return
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
return
tokenizer
.
encode
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
async
def
_tokenize_prompt_async
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
list
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
...
...
@@ -222,7 +219,6 @@ class InputPreprocessor:
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens
=
False
return
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
@@ -309,7 +305,6 @@ class InputPreprocessor:
def
_prompt_to_llm_inputs
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
...
...
@@ -318,7 +313,6 @@ class InputPreprocessor:
Arguments:
* request_id
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
...
...
@@ -333,7 +327,6 @@ class InputPreprocessor:
prompt_text
=
parsed
[
"content"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -384,7 +377,6 @@ class InputPreprocessor:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -400,7 +392,6 @@ class InputPreprocessor:
async
def
_prompt_to_llm_inputs_async
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
...
...
@@ -411,7 +402,6 @@ class InputPreprocessor:
prompt_text
=
parsed
[
"content"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -460,7 +450,6 @@ class InputPreprocessor:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt_text
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
...
...
@@ -560,7 +549,6 @@ class InputPreprocessor:
def
_process_encoder_decoder_prompt
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
)
->
EncoderDecoderInputs
:
"""
For encoder/decoder models only:
...
...
@@ -587,7 +575,6 @@ class InputPreprocessor:
Arguments:
* prompt: an input prompt
* request_id
Returns:
...
...
@@ -598,16 +585,11 @@ class InputPreprocessor:
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
prompt
[
"encoder_prompt"
])
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
else
:
decoder_inputs
=
self
.
_prompt_to_llm_inputs
(
decoder_input
,
request_id
=
request_id
,
)
decoder_inputs
=
self
.
_prompt_to_llm_inputs
(
decoder_input
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
...
...
@@ -616,10 +598,7 @@ class InputPreprocessor:
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
)
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
...
...
@@ -636,7 +615,6 @@ class InputPreprocessor:
async
def
_process_encoder_decoder_prompt_async
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
)
->
EncoderDecoderInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_inputs
:
SingletonInputs
...
...
@@ -644,18 +622,13 @@ class InputPreprocessor:
if
is_explicit_encoder_decoder_prompt
(
prompt
):
encoder_task
=
self
.
_prompt_to_llm_inputs_async
(
prompt
[
"encoder_prompt"
],
request_id
=
request_id
,
)
prompt
[
"encoder_prompt"
])
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
encoder_inputs
=
await
encoder_task
decoder_inputs
=
None
else
:
decoder_task
=
self
.
_prompt_to_llm_inputs_async
(
decoder_input
,
request_id
=
request_id
,
)
decoder_task
=
self
.
_prompt_to_llm_inputs_async
(
decoder_input
)
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
...
...
@@ -668,10 +641,7 @@ class InputPreprocessor:
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
)
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
# Encoder-Decoder Multimodal model
...
...
@@ -704,7 +674,6 @@ class InputPreprocessor:
def
_process_decoder_only_prompt
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -716,7 +685,6 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
* return_mm_hashes
...
...
@@ -728,7 +696,6 @@ class InputPreprocessor:
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
@@ -741,7 +708,6 @@ class InputPreprocessor:
async
def
_process_decoder_only_prompt_async
(
self
,
prompt
:
SingletonPrompt
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -749,7 +715,6 @@ class InputPreprocessor:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
@@ -762,7 +727,6 @@ class InputPreprocessor:
def
preprocess
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -774,10 +738,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
self
.
_process_encoder_decoder_prompt
(
prompt
,
request_id
=
request_id
,
)
return
self
.
_process_encoder_decoder_prompt
(
prompt
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -786,7 +747,6 @@ class InputPreprocessor:
# Decoder-only operation
return
self
.
_process_decoder_only_prompt
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
...
...
@@ -795,7 +755,6 @@ class InputPreprocessor:
async
def
preprocess_async
(
self
,
prompt
:
PromptType
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
...
...
@@ -807,10 +766,7 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
,
request_id
=
request_id
,
)
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -819,7 +775,6 @@ class InputPreprocessor:
# Decoder-only operation
return
await
self
.
_process_decoder_only_prompt_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
return_mm_hashes
,
...
...
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
da6ea29f
...
...
@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
@
abstractmethod
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group."""
...
...
@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group."""
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
da6ea29f
...
...
@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group.
...
...
@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor
=
actor
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
actor
.
encode
.
remote
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
))
except
ActorDiedError
as
e
:
...
...
@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor
=
self
.
_init_actor
()
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
actor
.
encode
.
remote
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
))
except
ActorDiedError
as
e
:
...
...
@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group.
...
...
@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor
=
actor
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor
=
self
.
_init_actor
()
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
da6ea29f
...
...
@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
tokenizer
=
self
.
get_lora_tokenizer
(
lora_request
)
...
...
@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
tokenizer
=
await
self
.
get_lora_tokenizer_async
(
lora_request
)
...
...
vllm/v1/engine/async_llm.py
View file @
da6ea29f
...
...
@@ -4,6 +4,7 @@ import asyncio
import
logging
import
os
from
collections.abc
import
AsyncGenerator
,
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
import
numpy
as
np
...
...
@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
cdiv
,
kill_process_tree
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
...
...
@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient):
)
->
asyncio
.
Queue
[
RequestOutput
]:
"""Add new request to the AsyncLLM."""
#
1)
Create a new output queue for the request.
# Create a new output queue for the request.
queue
:
asyncio
.
Queue
[
RequestOutput
]
=
asyncio
.
Queue
()
# 2) Fan out child requests (for n>1)
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
# Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
for
idx
in
range
(
n
):
if
parent_req
is
not
None
:
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
# 3) Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
if
n
==
1
:
await
self
.
_add_request
(
request
,
None
,
0
,
queue
)
return
queue
# 4) Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
,
queue
)
# Fan out child requests (for n>1).
parent_request
=
ParentRequest
(
request_id
,
params
)
for
idx
in
range
(
n
):
request_id
,
params
=
parent_request
.
get_child_info
(
idx
)
child_request
=
request
if
idx
==
n
-
1
else
copy
(
request
)
child_request
.
request_id
=
request_id
child_request
.
sampling_params
=
params
await
self
.
_add_request
(
child_request
,
parent_request
,
idx
,
queue
)
return
queue
# 5) Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
request
)
async
def
_add_request
(
self
,
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
],
index
:
int
,
queue
:
asyncio
.
Queue
[
RequestOutput
]):
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request_id
)
# Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
parent_req
,
index
,
queue
)
return
queue
# Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
request
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
...
...
vllm/v1/engine/llm_engine.py
View file @
da6ea29f
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing_extensions
import
TypeVar
...
...
@@ -179,25 +180,34 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
# 1) Fan out child requests (for n>1)
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
for
idx
in
range
(
n
):
if
parent_req
is
not
None
:
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
# 2) Process raw inputs into the request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
# Process raw inputs into the request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
# 3) Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
# 3) Add the request to EngineCore.
if
n
==
1
:
# Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
request
,
None
,
0
)
# Add the request to EngineCore.
self
.
engine_core
.
add_request
(
request
)
return
# Fan out child requests (for n>1).
parent_req
=
ParentRequest
(
request_id
,
params
)
for
idx
in
range
(
n
):
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
child_request
=
request
if
idx
==
n
-
1
else
copy
(
request
)
child_request
.
request_id
=
request_id
child_request
.
sampling_params
=
params
# Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
child_request
,
parent_req
,
idx
)
# Add the request to EngineCore.
self
.
engine_core
.
add_request
(
child_request
)
def
step
(
self
)
->
list
[
RequestOutput
]:
...
...
vllm/v1/engine/parallel_sampling.py
View file @
da6ea29f
# SPDX-License-Identifier: Apache-2.0
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Optional
from
vllm.outputs
import
CompletionOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.v1.metrics.stats
import
IterationStats
...
...
@@ -43,16 +42,6 @@ class ParentRequest:
self
.
max_num_generation_tokens
=
0
self
.
cached_child_sampling_params
=
None
@
classmethod
def
from_params
(
cls
,
request_id
:
str
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
Optional
[
'ParentRequest'
]:
if
not
isinstance
(
params
,
SamplingParams
)
or
params
.
n
==
1
:
return
None
return
cls
(
request_id
,
params
)
def
_get_child_sampling_params
(
self
,
index
:
int
,
...
...
vllm/v1/engine/processor.py
View file @
da6ea29f
...
...
@@ -173,7 +173,6 @@ class Processor:
# 3. Apply prompt adapter to prompt token ids if one exists.
processed_inputs
:
ProcessorInputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
return_mm_hashes
=
self
.
use_hash
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment