Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7eb4a51c
Unverified
Commit
7eb4a51c
authored
Aug 09, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 09, 2024
Browse files
[Core] Support serving encoder/decoder models (#7258)
parent
0fa14907
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
563 additions
and
394 deletions
+563
-394
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+1
-1
examples/offline_inference_encoder_decoder.py
examples/offline_inference_encoder_decoder.py
+4
-4
requirements-common.txt
requirements-common.txt
+1
-1
requirements-lint.txt
requirements-lint.txt
+1
-1
tests/conftest.py
tests/conftest.py
+19
-13
tests/distributed/test_basic_distributed_correctness_enc_dec.py
...distributed/test_basic_distributed_correctness_enc_dec.py
+1
-1
tests/entrypoints/openai/test_encoder_decoder.py
tests/entrypoints/openai/test_encoder_decoder.py
+50
-0
tests/models/test_bart.py
tests/models/test_bart.py
+27
-11
tests/models/utils.py
tests/models/utils.py
+0
-11
tests/test_inputs.py
tests/test_inputs.py
+1
-1
vllm/config.py
vllm/config.py
+10
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+130
-24
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+151
-171
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+5
-3
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-1
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+10
-11
vllm/inputs/data.py
vllm/inputs/data.py
+72
-135
vllm/inputs/parse.py
vllm/inputs/parse.py
+75
-0
No files found.
.github/workflows/mypy.yaml
View file @
7eb4a51c
...
...
@@ -25,7 +25,7 @@ jobs:
-
name
:
Install dependencies
run
:
|
python -m pip install --upgrade pip
pip install mypy==1.
9.0
pip install mypy==1.
11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
...
...
examples/offline_inference_encoder_decoder.py
View file @
7eb4a51c
...
...
@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
'''
from
vllm
import
LLM
,
SamplingParams
from
vllm.inputs
import
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.utils
import
zip_enc_dec_prompt
_list
s
from
vllm.inputs
import
(
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
TokensPrompt
,
zip_enc_dec_prompts
)
dtype
=
"float"
...
...
@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
)
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt
list
s together into a list of ExplicitEncoderDecoderPrompt
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
zipped_prompt_list
=
zip_enc_dec_prompt
_list
s
(
zipped_prompt_list
=
zip_enc_dec_prompts
(
[
'An encoder prompt'
,
'Another encoder prompt'
],
[
'A decoder prompt'
,
'Another decoder prompt'
])
...
...
requirements-common.txt
View file @
7eb4a51c
...
...
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
typing_extensions
>= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
gguf == 0.9.1
requirements-lint.txt
View file @
7eb4a51c
...
...
@@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5
# type checking
mypy==1.
9.0
mypy==1.
11.1
types-PyYAML
types-requests
types-setuptools
tests/conftest.py
View file @
7eb4a51c
...
...
@@ -3,6 +3,7 @@ import gc
import
os
import
sys
from
collections
import
UserList
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
import
pytest
...
...
@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
,
BatchFeature
)
from
tests.models.utils
import
DecoderPromptType
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
TokenizerPoolConfig
from
vllm.connections
import
global_http_connection
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
from
vllm.inputs
import
TextPrompt
from
vllm.inputs
import
(
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
is_cpu
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompt_lists
)
is_cpu
)
logger
=
init_logger
(
__name__
)
...
...
@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
return
prompts
class
DecoderPromptType
(
Enum
):
"""For encoder/decoder models only."""
CUSTOM
=
1
NONE
=
2
EMPTY_STR
=
3
@
pytest
.
fixture
def
example_encoder_decoder_prompts
()
\
->
Dict
[
DecoderPromptType
,
Tuple
[
List
[
str
],
List
[
Optional
[
str
]]]]:
def
example_encoder_decoder_prompts
(
)
->
Dict
[
DecoderPromptType
,
List
[
ExplicitEncoderDecoderPrompt
]]:
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
...
...
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
# NONE decoder prompt type
return
{
DecoderPromptType
.
NONE
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
none_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
none_decoder_prompts
),
DecoderPromptType
.
EMPTY_STR
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
empty_str_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
empty_str_decoder_prompts
),
DecoderPromptType
.
CUSTOM
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
custom_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
custom_decoder_prompts
),
}
...
...
@@ -444,7 +450,7 @@ class HfRunner:
def
generate_encoder_decoder_greedy_logprobs_limit
(
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
max_tokens
:
int
,
num_logprobs
:
int
,
**
kwargs
:
Any
,
...
...
@@ -608,7 +614,7 @@ class VllmRunner:
def
generate_encoder_decoder_w_logprobs
(
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
sampling_params
:
SamplingParams
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
'''
...
...
@@ -653,7 +659,7 @@ class VllmRunner:
def
generate_encoder_decoder_greedy_logprobs
(
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
...
...
tests/distributed/test_basic_distributed_correctness_enc_dec.py
View file @
7eb4a51c
...
...
@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
import
pytest
from
tests.models.utils
import
DecoderPromptType
from
vllm.utils
import
cuda_device_count_stateless
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
from
..utils
import
fork_new_process_for_each_test
...
...
tests/entrypoints/openai/test_encoder_decoder.py
0 → 100644
View file @
7eb4a51c
import
openai
import
pytest
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"facebook/bart-base"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--dtype"
,
"bfloat16"
,
"--enforce-eager"
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
def
client
(
server
):
return
server
.
get_async_client
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_single_completion
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
choice
=
completion
.
choices
[
0
]
assert
len
(
choice
.
text
)
>=
5
assert
choice
.
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
2
,
total_tokens
=
7
)
# test using token IDs
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
[
0
,
0
,
0
,
0
,
0
],
max_tokens
=
5
,
temperature
=
0.0
,
)
assert
len
(
completion
.
choices
[
0
].
text
)
>=
1
tests/models/test_bart.py
View file @
7eb4a51c
...
...
@@ -2,6 +2,8 @@
Run `pytest tests/models/test_bart.py`.
"""
from
typing
import
List
,
Optional
,
Tuple
from
vllm.utils
import
is_cpu
if
not
is_cpu
():
...
...
@@ -11,22 +13,31 @@ if not is_cpu():
import
pytest
from
tests.models.utils
import
DecoderPromptType
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
DecoderPromptType
from
.utils
import
check_logprobs_close
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
DECODER_PROMPT_TYPES
=
([
DecoderPromptType
.
CUSTOM
,
DecoderPromptType
.
EMPTY_STR
,
DecoderPromptType
.
NONE
])
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
decoder_prompt_type
:
DecoderPromptType
,
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
hf_output_str
=
output_str
+
"</s>"
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
:
hf_output_str
=
"<s>"
+
hf_output_str
return
output_ids
,
hf_output_str
,
out_logprobs
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
DECODER_PROMPT_TYPES
)
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
)
)
def
test_models
(
hf_runner
,
vllm_runner
,
...
...
@@ -146,8 +157,13 @@ if not is_cpu():
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
for
vllm_output
in
vllm_outputs
],
name_0
=
"hf"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
)
tests/models/utils.py
View file @
7eb4a51c
import
warnings
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
vllm.sequence
import
SampleLogprobs
...
...
@@ -136,13 +135,3 @@ def check_logprobs_close(
warnings
.
simplefilter
(
"always"
)
warnings
.
warn
(
fail_msg
,
stacklevel
=
2
)
class
DecoderPromptType
(
Enum
):
'''
For encoder/decoder models only -
'''
CUSTOM
=
1
NONE
=
2
EMPTY_STR
=
3
tests/test_inputs.py
View file @
7eb4a51c
...
...
@@ -2,7 +2,7 @@ from typing import List
import
pytest
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
.parse
import
parse_and_batch_prompt
STRING_INPUTS
=
[
''
,
...
...
vllm/config.py
View file @
7eb4a51c
...
...
@@ -464,6 +464,16 @@ class ModelConfig:
if
t
!=
"attention"
])
@
property
def
is_encoder_decoder_model
(
self
)
->
bool
:
"""Extract the HF encoder/decoder model flag."""
return
getattr
(
self
.
hf_config
,
"is_encoder_decoder"
,
False
)
@
property
def
is_embedding_model
(
self
)
->
bool
:
"""Extract the embedding model flag."""
return
self
.
embedding_mode
class
CacheConfig
:
"""Configuration for the KV cache.
...
...
vllm/engine/async_llm_engine.py
View file @
7eb4a51c
...
...
@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
(
DecoderPromptComponents
,
LLMEngine
,
PromptComponents
)
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.inputs
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
...
@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
async
def
process_model_inputs_async
(
async
def
_tokenize_prompt_async
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
List
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
return
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
async
def
_extract_prompt_components_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
"""Async version of :meth:`_extract_prompt_components`."""
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
else
:
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
parsed_prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
else
:
assert_never
(
inputs
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
async
def
_process_encoder_decoder_prompt_async
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
encoder_task
=
self
.
_extract_prompt_components_async
(
inputs
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
else
:
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_input
,
request_id
=
request_id
,
)
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
async
def
_process_decoder_only_prompt_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
if
isinstance
(
inputs
,
str
):
inputs
=
{
"prompt"
:
inputs
}
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
if
"prompt_token_ids"
not
in
inputs
:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
prompt_token_ids
=
await
tokenizer
.
encode_async
(
async
def
process_model_inputs_async
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]:
"""Async version of :meth:`process_model_inputs`."""
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs
=
await
self
.
_process_encoder_decoder_prompt_async
(
inputs
,
request_id
=
request_id
,
prompt
=
inputs
[
"prompt"
],
lora_request
=
lora_request
)
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
if
prompt_adapter_request
:
prompt_token_ids
=
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
\
prompt_token_ids
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
inputs
.
get
(
"prompt"
),
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
# Decoder-only operation
model_inputs
=
await
self
.
_process_decoder_only_prompt_async
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
input_processor
(
llm
_inputs
)
return
self
.
input_processor
(
model
_inputs
)
async
def
add_request_async
(
self
,
...
...
@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
"""Async version of :meth:`add_request`."""
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
processed_inputs
=
await
self
.
process_model_inputs_async
(
inputs
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
request_id
=
request_id
,
...
...
vllm/engine/llm_engine.py
View file @
7eb4a51c
...
...
@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
LLMInputs
,
PromptInputs
,
get_prompt_type
)
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer
,
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
(
Counter
,
is_embedding_model_config
,
is_encoder_decoder_model_config
)
from
vllm.utils
import
Counter
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
EmbeddingRequestOutput
)
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
Optional
[
MultiModalDataDict
]]
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
Optional
[
MultiModalDataDict
]]
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -524,7 +532,7 @@ class LLMEngine:
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
_get_decoder_start_token_id
(
self
,
)
->
Optional
[
int
]:
def
_get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
...
...
@@ -553,7 +561,7 @@ class LLMEngine:
def
_add_processed_request
(
self
,
request_id
:
str
,
processed_inputs
:
LLMInputs
,
processed_inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
...
...
@@ -613,11 +621,11 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
_LLMInputComponentsType
=
Tuple
[
str
,
List
[
int
]
,
]
_LLMInputComponentsType
=
Tuple
[
str
,
List
[
int
]]
def
_prepare_decoder_input_ids_for_generation
(
self
,
decoder_input_ids
:
Optional
[
List
[
int
]]
=
None
,
decoder_input_ids
:
Optional
[
List
[
int
]],
)
->
List
[
int
]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
...
...
@@ -639,14 +647,13 @@ class LLMEngine:
* Processed token list
"""
decoder_start_token_id
:
Optional
[
int
]
=
(
self
.
_get_decoder_start_token_id
())
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
()
assert
decoder_start_token_id
is
not
None
if
decoder_input_ids
is
None
:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
(
decoder_input_ids
)
=
self
.
_get_default_enc_dec_decoder_prompt
()
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
if
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
...
...
@@ -657,12 +664,11 @@ class LLMEngine:
def
_tokenize_prompt
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
str
]
=
None
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
,
)
->
List
[
int
]:
'''
Wrapper around application of the model's
tokenizer.
Wrapper around application of the model's tokenizer.
Arguments:
...
...
@@ -678,87 +684,72 @@ class LLMEngine:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
prompt_token_ids
=
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
return
prompt_token_ids
return
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
def
_extract_
single_prompt_for_enc_dec_input
(
def
_extract_
prompt_components
(
self
,
inputs
:
Optional
[
PromptInputs
],
request_id
:
Optional
[
str
]
=
None
,
ptype
:
Optional
[
str
]
=
None
,
is_encoder_prompt
:
bool
=
False
,
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
'''
Only for encoder/decoder models:
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt.
If False, decoder prompt tokens
are preprocessed.
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
prompt_token_ids
=
None
ptype
=
(
get_prompt_type
(
inputs
)
if
ptype
is
None
else
ptype
)
if
inputs
is
None
:
prompt
=
None
elif
ptype
==
'str'
:
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
elif
ptype
==
'TokensPrompt'
:
prompt
=
None
prompt_token_ids
=
inputs
[
'prompt_token_ids'
]
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
else
:
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
parsed_prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
else
:
prompt
=
inputs
[
'prompt'
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
request_id
=
request_id
,
)
assert_never
(
inputs
)
if
not
is_encoder_prompt
:
# Apply special pre-processing to
# decoder prompts
prompt_token_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
prompt_token_ids
,
))
return
prompt
,
prompt_token_ids
,
multi_modal_data
assert
prompt_token_ids
is
not
None
def
_apply_prompt_adapter
(
self
,
prompt_token_ids
:
List
[
int
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
List
[
int
]:
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
return
(
prompt
,
prompt_token_ids
,
)
return
prompt_token_ids
def
_get_default_enc_dec_decoder_prompt
(
self
,
)
->
List
[
int
]:
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
List
[
int
]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
...
...
@@ -792,18 +783,39 @@ class LLMEngine:
bos_token_id
=
self
.
_get_bos_token_id
()
assert
bos_token_id
is
not
None
prompt_token_ids
:
List
[
int
]
=
[
bos_token_id
]
return
prompt_token_ids
return
[
bos_token_id
]
def
_build_enc_dec_llm_inputs
(
self
,
encoder_comps
:
PromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
)
->
EncoderDecoderLLMInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
if
encoder_mm_data
is
not
None
or
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modal encoder-decoder models are "
"not supported yet"
)
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
))
return
EncoderDecoderLLMInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
)
def
_process_encoder_decoder_prompt
(
self
,
inputs
:
PromptInputs
,
request_id
:
Optional
[
str
]
=
None
,
)
->
LLMInputs
:
request_id
:
str
,
)
->
EncoderDecoder
LLMInputs
:
'''
For encoder/decoder models only:
Process an input prompt
into an `
LLMInputs` instance.
Process an input prompt
into an
:class:`EncoderDecoder
LLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
...
...
@@ -830,136 +842,103 @@ class LLMEngine:
Returns:
*
`
LLMInputs` instance
*
:class:`EncoderDecoder
LLMInputs` instance
'''
ptype
=
get_prompt_type
(
inputs
)
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder
# prompt type is unknown.
if
ptype
==
"ExplicitEncoderDecoder"
:
# If input is explicit encoder/decoder prompt,
# then it remains to be determined what type
# of encoder prompt we have
extracted_encoder_prompt
=
inputs
.
get
(
'encoder_prompt'
)
encoder_ptype
=
None
# Extract decoder prompt from explicit
# encoder/decoder prompt
extracted_decoder_prompt
=
inputs
.
get
(
'decoder_prompt'
)
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
else
:
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_input
,
request_id
=
request_id
,
)
else
:
# If input is singleton encoder prompt, then
# we know the encoder prompt type
extracted_encoder_prompt
=
inputs
encoder_ptype
=
ptype
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt
=
None
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt
,
encoder_prompt_token_ids
,
)
=
self
.
_extract_single_prompt_for_enc_dec_input
(
extracted_encoder_prompt
,
request_id
=
request_id
,
ptype
=
encoder_ptype
,
is_encoder_prompt
=
True
,
)
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
,
request_id
=
request_id
,
)
# Invoke helper method to obtain
# decoder prompt and prompt token ids.
#
# The helper method will detect the decoder
# prompt type.
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt
,
decoder_prompt_token_ids
,
)
=
self
.
_extract_single_prompt_for_enc_dec_input
(
extracted_decoder_prompt
,
request_id
=
request_id
,
ptype
=
None
,
is_encoder_prompt
=
False
,
)
decoder_comps
=
None
,
None
,
None
return
LLMInputs
(
prompt_token_ids
=
decoder_prompt_token_ids
,
prompt
=
decoder_prompt
,
encoder_prompt_token_ids
=
encoder_prompt_token_ids
,
encoder_prompt
=
encoder_prompt
,
)
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
def
_build_decoder_only_llm_inputs
(
self
,
prompt_comps
:
PromptComponents
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
LLMInputs
:
prompt
,
prompt_token_ids
,
multi_modal_data
=
prompt_comps
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
def
_process_decoder_only_prompt
(
self
,
inputs
:
PromptInputs
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
'''
For decoder-only models:
Process an input prompt
into an `LLMInputs` instance.
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* lora_request
* request_id
* lora_request
* prompt_adapter_request
Returns:
* `LLMInputs` instance
*
:class:
`LLMInputs` instance
'''
if
isinstance
(
inputs
,
str
):
inputs
=
{
"prompt"
:
inputs
}
prompt
=
inputs
.
get
(
"prompt"
)
if
"prompt_token_ids"
not
in
inputs
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
prompt_comps
=
self
.
_extract_prompt_components
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
def
process_model_inputs
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]
:
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs
=
self
.
_process_encoder_decoder_prompt
(
inputs
,
request_id
=
request_id
,
)
else
:
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
# Decoder-only operation
model_inputs
=
self
.
_process_decoder_only_prompt
(
inputs
,
...
...
@@ -1029,10 +1008,11 @@ class LLMEngine:
arrival_time
=
time
.
time
()
processed_inputs
=
self
.
process_model_inputs
(
inputs
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
request_id
=
request_id
,
...
...
@@ -1597,7 +1577,7 @@ class LLMEngine:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_E2E
,
e2e_time
)
def
is_encoder_decoder_model
(
self
):
return
is_encoder_decoder_model
_config
(
self
.
model_config
)
return
self
.
model_config
.
is_encoder_decoder_model
def
is_embedding_model
(
self
):
return
is_embedding_model
_config
(
self
.
model_config
)
return
self
.
model_config
.
is_embedding_model
vllm/entrypoints/chat_utils.py
View file @
7eb4a51c
...
...
@@ -2,8 +2,7 @@ import codecs
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
from
typing
import
Any
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam
]
@
final
# So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
...
...
vllm/entrypoints/llm.py
View file @
7eb4a51c
...
...
@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.inputs
import
(
PromptInputs
,
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
...
...
vllm/entrypoints/openai/logits_processors.py
View file @
7eb4a51c
...
...
@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return
AllowedTokenIdsLogitsProcessor
(
allowed_token_ids
)
def
logit_bias_logits_processor
(
logit_bias
:
Dict
[
str
,
float
],
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
logit_bias_logits_processor
(
logit_bias
:
Dict
[
int
,
float
],
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
7eb4a51c
...
...
@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest
,
TokenizeRequest
)
# yapf: enable
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
...
...
vllm/inputs/__init__.py
View file @
7eb4a51c
from
.data
import
(
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
ParsedText
,
ParsedToken
s
,
PromptInputs
,
SingletonPromptInputs
,
TextPrompt
,
TokensPrompt
,
get
_prompt
_type
,
is_valid_encoder_decoder_llm_inputs
,
parse_and_batch
_prompt
)
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInput
s
,
PromptInputs
,
SingletonPromptInputs
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec
_prompt
,
to_enc_dec_tuple_list
,
zip_enc_dec
_prompt
s
)
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
...
...
@@ -14,18 +14,17 @@ See also:
"""
__all__
=
[
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"PromptInputs"
,
"SingletonPromptInputs"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
"build_explicit_enc_dec_prompt"
,
"to_enc_dec_tuple_list"
,
"zip_enc_dec_prompts"
,
"INPUT_REGISTRY"
,
"InputContext"
,
"InputRegistry"
,
"get_prompt_type"
,
"is_valid_encoder_decoder_llm_inputs"
,
"ExplicitEncoderDecoderPrompt"
,
"SingletonPromptInputs"
,
]
vllm/inputs/data.py
View file @
7eb4a51c
from
typing
import
(
TYPE_CHECKING
,
List
,
Literal
,
Optional
,
Sequenc
e
,
TypedDict
,
Union
,
cast
,
overload
)
from
typing
import
(
TYPE_CHECKING
,
Generic
,
Iterable
,
List
,
Optional
,
Tupl
e
,
Union
)
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
class ParsedText(TypedDict):
    """One text prompt, as emitted by :func:`parse_and_batch_prompt`."""

    # The prompt string itself.
    content: str
    # Discriminator field; text entries always carry ``False`` here.
    is_tokens: Literal[False]
class ParsedTokens(TypedDict):
    """One pre-tokenized prompt, as emitted by :func:`parse_and_batch_prompt`."""

    # The token IDs that make up the prompt.
    content: List[int]
    # Discriminator field; token entries always carry ``True`` here.
    is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of parsed prompts.

    Four input shapes are accepted: a string, a list of strings, a list
    of token IDs, or a list of token-ID lists.  The shape is decided by
    looking at the first element only (and, for nested lists, the first
    element of the first inner list).

    Arguments:

    * prompt: the prompt(s) in one of the shapes listed above

    Returns:

    * A list of :class:`ParsedText` for string input, or a list of
      :class:`ParsedTokens` for token input

    Raises:

    * ValueError: if the list (or its first inner list) is empty, or if
      the input matches none of the supported shapes
    """
    # case 1: a string
    if isinstance(prompt, str):
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

        # The first element alone selects which case applies.
        head = prompt[0]

        if isinstance(head, str):
            # case 2: array of strings
            texts = cast(List[str], prompt)
            return [
                ParsedText(content=text, is_tokens=False) for text in texts
            ]

        if isinstance(head, int):
            # case 3: array of tokens (the whole list is one prompt)
            token_ids = cast(List[int], prompt)
            return [ParsedTokens(content=token_ids, is_tokens=True)]

        if isinstance(head, list):
            if len(head) == 0:
                raise ValueError("please provide at least one prompt")

            if isinstance(head[0], int):
                # case 4: array of token arrays
                batches = cast(List[List[int]], prompt)
                return [
                    ParsedTokens(content=ids, is_tokens=True)
                    for ids in batches
                ]

    # Anything that reached this point matched none of the four cases.
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
class
TextPrompt
(
TypedDict
):
"""Schema for a text prompt."""
...
...
@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
prompts explicitly, i.e.
:class:`
ExplicitEncoderDecoderPrompt
`
A prompt of type SingletonPromptInputs may be employed
A prompt of type
:class:`
SingletonPromptInputs
`
may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
more than one prompt, i.e.
:class:`
ExplicitEncoderDecoderPrompt
`
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
,
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
,
covariant
=
True
)
class
ExplicitEncoderDecoderPrompt
(
TypedDict
):
# TODO: Make fields ReadOnly once mypy supports it
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
SingletonPromptInputs schemas, and are not
:class:`
SingletonPromptInputs
`
schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
Note that an
:class:`
ExplicitEncoderDecoderPrompt
`
may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure
may not
themselves
must be
SingletonPromptInputs instances.
fields of this data structure themselves
must be
:class:`
SingletonPromptInputs
`
instances.
"""
encoder_prompt
:
SingletonPromptInputs
encoder_prompt
:
_T1_co
decoder_prompt
:
SingletonPromptInputs
decoder_prompt
:
Optional
[
_T2_co
]
PromptInputs
=
Union
[
SingletonPromptInputs
,
ExplicitEncoderDecoderPrompt
]
...
...
@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
"""
def
_has_required_keys
(
d
:
dict
,
required_keys
:
set
,
)
->
bool
:
return
required_keys
.
issubset
(
d
.
keys
())
def get_prompt_type(prompt: "Optional[PromptInputs]") -> Optional[str]:
    """
    Get the type-name of the prompt argument instance, given that
    isinstance() cannot apply to TypedDict subclasses directly.
    If the prompt is None, return 'None' as the type name.

    Dict prompts are classified by the keys they contain; the candidate
    schemas are tried in declaration order and the first one whose
    required keys are all present wins.

    Arguments:

    * prompt: LLM input prompt or None

    Returns:

    * String representation of prompt type: 'None', 'str', 'TextPrompt',
      'TokensPrompt' or 'ExplicitEncoderDecoder'

    Raises:

    * ValueError: if a dict prompt matches no known schema, or if the
      prompt is neither None, a dict nor a str
    """
    if prompt is None:
        return 'None'

    required_keys_dict = {
        'TextPrompt': {'prompt'},
        'TokensPrompt': {'prompt_token_ids'},
        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
    }

    if isinstance(prompt, dict):
        present_keys = prompt.keys()
        for ptype, required_keys in required_keys_dict.items():
            if required_keys.issubset(present_keys):
                return ptype

        # Both literals must be f-strings; otherwise the placeholder text
        # "{required_keys_dict}" is emitted verbatim instead of the table.
        raise ValueError(f"Invalid prompt {prompt}, valid types are "
                         f"required_keys_dict={required_keys_dict}")

    if isinstance(prompt, str):
        return "str"

    raise ValueError(f"Invalid prompt {prompt}")
class
LLMInputs
(
TypedDict
):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt."""
...
...
@@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
encoder_prompt_token_ids
:
NotRequired
[
List
[
int
]]
multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class
EncoderDecoderLLMInputs
(
LLMInputs
):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids
:
List
[
int
]
"""The token IDs of the encoder prompt."""
encoder_prompt
:
NotRequired
[
Optional
[
str
]]
...
...
@@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
available.
"""
multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
def
is_valid_encoder_decoder_llm_inputs
(
inputs
:
LLMInputs
)
->
bool
:
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Wrap an encoder prompt and a decoder prompt into a single
    :class:`ExplicitEncoderDecoderPrompt`.

    Arguments:

    * encoder_prompt: prompt stored under `encoder_prompt`
    * decoder_prompt: prompt stored under `decoder_prompt`; may be None

    Returns:

    * :class:`ExplicitEncoderDecoderPrompt` instance
    """
    enc_dec_prompt = ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
    )
    return enc_dec_prompt
def
zip_enc_dec_prompts
(
enc_prompts
:
Iterable
[
_T1
],
dec_prompts
:
Iterable
[
Optional
[
_T2
]],
)
->
List
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
"""
Return True if the LLMInputs instance has the correct configuration
for e
ncoder
/d
ecoder.
Zip encoder and decoder prompts together into a list of
:class:`ExplicitE
ncoder
D
ecoder
Prompt` instances
.
"""
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
)
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
]
# True if encoder prompt token ids field exists &
# is not None
return
(
'encoder_prompt_token_ids'
in
inputs
and
inputs
[
'encoder_prompt_token_ids'
]
is
not
None
)
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
    """
    Unpack explicit encoder/decoder prompts into
    `(encoder_prompt, decoder_prompt)` tuples, preserving input order.

    Arguments:

    * enc_dec_prompts: iterable of prompts exposing the
      `encoder_prompt` and `decoder_prompt` keys

    Returns:

    * List with one tuple per input prompt
    """
    pairs = []
    for item in enc_dec_prompts:
        pairs.append((item["encoder_prompt"], item["decoder_prompt"]))
    return pairs
vllm/inputs/parse.py
0 → 100644
View file @
7eb4a51c
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
overload
from
typing_extensions
import
TypeIs
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
PromptInputs
)
class ParsedText(TypedDict):
    """One text prompt, as emitted by :func:`parse_and_batch_prompt`."""

    # The prompt string itself.
    content: str
    # Discriminator field; text entries always carry ``False`` here.
    is_tokens: Literal[False]
class ParsedTokens(TypedDict):
    """One pre-tokenized prompt, as emitted by :func:`parse_and_batch_prompt`."""

    # The token IDs that make up the prompt.
    content: List[int]
    # Discriminator field; token entries always carry ``True`` here.
    is_tokens: Literal[True]
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of parsed prompts.

    Four input shapes are accepted: a string, a list of strings, a list
    of token IDs, or a list of token-ID lists.  List shapes are detected
    with ``is_list_of`` from ``vllm.utils``.

    NOTE(review): how many elements ``is_list_of`` inspects is not
    visible here -- confirm against ``vllm.utils``.

    Arguments:

    * prompt: the prompt(s) in one of the shapes listed above

    Returns:

    * A list of :class:`ParsedText` for string input, or a list of
      :class:`ParsedTokens` for token input

    Raises:

    * ValueError: if the list (or its first inner list) is empty, or if
      the input matches none of the supported shapes
    """
    # case 1: a string
    if isinstance(prompt, str):
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            return [
                ParsedText(content=text, is_tokens=False) for text in prompt
            ]

        if is_list_of(prompt, int):
            # case 3: array of tokens (the whole list is one prompt)
            return [ParsedTokens(content=prompt, is_tokens=True)]

        if is_list_of(prompt, list):
            first_inner = prompt[0]
            if len(first_inner) == 0:
                raise ValueError("please provide at least one prompt")

            # Only the first inner list is checked for int elements.
            if is_list_of(first_inner, int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=ids, is_tokens=True)
                    for ids in prompt
                ]

    # Anything that reached this point matched none of the four cases.
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
def is_explicit_encoder_decoder_prompt(
        inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """
    Return True if `inputs` is a dict containing an `encoder_prompt` key.

    Only the presence of the key is tested; the value stored under it
    is not inspected.
    """
    # Non-dict prompts (e.g. plain strings) can never be explicit
    # encoder/decoder prompts.
    if not isinstance(inputs, dict):
        return False
    return "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
    """
    Return True if `inputs` contains an `encoder_prompt_token_ids` key.

    Only the presence of the key is tested; the value stored under it
    is not inspected.
    """
    has_encoder_token_ids = "encoder_prompt_token_ids" in inputs
    return has_encoder_token_ids
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment