Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7eb4a51c
Unverified
Commit
7eb4a51c
authored
Aug 09, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 09, 2024
Browse files
[Core] Support serving encoder/decoder models (#7258)
parent
0fa14907
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
563 additions
and
394 deletions
+563
-394
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+1
-1
examples/offline_inference_encoder_decoder.py
examples/offline_inference_encoder_decoder.py
+4
-4
requirements-common.txt
requirements-common.txt
+1
-1
requirements-lint.txt
requirements-lint.txt
+1
-1
tests/conftest.py
tests/conftest.py
+19
-13
tests/distributed/test_basic_distributed_correctness_enc_dec.py
...distributed/test_basic_distributed_correctness_enc_dec.py
+1
-1
tests/entrypoints/openai/test_encoder_decoder.py
tests/entrypoints/openai/test_encoder_decoder.py
+50
-0
tests/models/test_bart.py
tests/models/test_bart.py
+27
-11
tests/models/utils.py
tests/models/utils.py
+0
-11
tests/test_inputs.py
tests/test_inputs.py
+1
-1
vllm/config.py
vllm/config.py
+10
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+130
-24
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+151
-171
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-3
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-2
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+5
-3
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-1
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+10
-11
vllm/inputs/data.py
vllm/inputs/data.py
+72
-135
vllm/inputs/parse.py
vllm/inputs/parse.py
+75
-0
No files found.
.github/workflows/mypy.yaml
View file @
7eb4a51c
...
@@ -25,7 +25,7 @@ jobs:
...
@@ -25,7 +25,7 @@ jobs:
-
name
:
Install dependencies
-
name
:
Install dependencies
run
:
|
run
:
|
python -m pip install --upgrade pip
python -m pip install --upgrade pip
pip install mypy==1.
9.0
pip install mypy==1.
11.1
pip install types-setuptools
pip install types-setuptools
pip install types-PyYAML
pip install types-PyYAML
pip install types-requests
pip install types-requests
...
...
examples/offline_inference_encoder_decoder.py
View file @
7eb4a51c
...
@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
...
@@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
'''
'''
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.inputs
import
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
(
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
from
vllm.utils
import
zip_enc_dec_prompt
_list
s
TokensPrompt
,
zip_enc_dec_prompts
)
dtype
=
"float"
dtype
=
"float"
...
@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
...
@@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
)
)
# - Finally, here's a useful helper function for zipping encoder and
# - Finally, here's a useful helper function for zipping encoder and
# decoder prompt
list
s together into a list of ExplicitEncoderDecoderPrompt
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
# instances
# instances
zipped_prompt_list
=
zip_enc_dec_prompt
_list
s
(
zipped_prompt_list
=
zip_enc_dec_prompts
(
[
'An encoder prompt'
,
'Another encoder prompt'
],
[
'An encoder prompt'
,
'Another encoder prompt'
],
[
'A decoder prompt'
,
'Another decoder prompt'
])
[
'A decoder prompt'
,
'Another decoder prompt'
])
...
...
requirements-common.txt
View file @
7eb4a51c
...
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
...
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
typing_extensions
>= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
pyzmq
gguf == 0.9.1
gguf == 0.9.1
requirements-lint.txt
View file @
7eb4a51c
...
@@ -8,7 +8,7 @@ isort==5.13.2
...
@@ -8,7 +8,7 @@ isort==5.13.2
clang-format==18.1.5
clang-format==18.1.5
# type checking
# type checking
mypy==1.
9.0
mypy==1.
11.1
types-PyYAML
types-PyYAML
types-requests
types-requests
types-setuptools
types-setuptools
tests/conftest.py
View file @
7eb4a51c
...
@@ -3,6 +3,7 @@ import gc
...
@@ -3,6 +3,7 @@ import gc
import
os
import
os
import
sys
import
sys
from
collections
import
UserList
from
collections
import
UserList
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
import
pytest
import
pytest
...
@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
...
@@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
,
AutoModelForVision2Seq
,
AutoTokenizer
,
BatchEncoding
,
BatchFeature
)
BatchFeature
)
from
tests.models.utils
import
DecoderPromptType
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.connections
import
global_http_connection
from
vllm.connections
import
global_http_connection
from
vllm.distributed
import
(
destroy_distributed_environment
,
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
destroy_model_parallel
)
from
vllm.inputs
import
TextPrompt
from
vllm.inputs
import
(
ExplicitEncoderDecoderPrompt
,
TextPrompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
is_cpu
,
to_enc_dec_tuple_list
,
is_cpu
)
zip_enc_dec_prompt_lists
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
...
@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
return
prompts
return
prompts
class
DecoderPromptType
(
Enum
):
"""For encoder/decoder models only."""
CUSTOM
=
1
NONE
=
2
EMPTY_STR
=
3
@
pytest
.
fixture
@
pytest
.
fixture
def
example_encoder_decoder_prompts
()
\
def
example_encoder_decoder_prompts
(
->
Dict
[
DecoderPromptType
,
)
->
Dict
[
DecoderPromptType
,
List
[
ExplicitEncoderDecoderPrompt
]]:
Tuple
[
List
[
str
],
List
[
Optional
[
str
]]]]:
'''
'''
Returns an encoder prompt list and a decoder prompt list, wherein each pair
Returns an encoder prompt list and a decoder prompt list, wherein each pair
of same-index entries in both lists corresponds to an (encoder prompt,
of same-index entries in both lists corresponds to an (encoder prompt,
...
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
...
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
# NONE decoder prompt type
# NONE decoder prompt type
return
{
return
{
DecoderPromptType
.
NONE
:
DecoderPromptType
.
NONE
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
none_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
none_decoder_prompts
),
DecoderPromptType
.
EMPTY_STR
:
DecoderPromptType
.
EMPTY_STR
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
empty_str_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
empty_str_decoder_prompts
),
DecoderPromptType
.
CUSTOM
:
DecoderPromptType
.
CUSTOM
:
zip_enc_dec_prompt
_list
s
(
encoder_prompts
,
custom_decoder_prompts
),
zip_enc_dec_prompts
(
encoder_prompts
,
custom_decoder_prompts
),
}
}
...
@@ -444,7 +450,7 @@ class HfRunner:
...
@@ -444,7 +450,7 @@ class HfRunner:
def
generate_encoder_decoder_greedy_logprobs_limit
(
def
generate_encoder_decoder_greedy_logprobs_limit
(
self
,
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
...
@@ -608,7 +614,7 @@ class VllmRunner:
...
@@ -608,7 +614,7 @@ class VllmRunner:
def
generate_encoder_decoder_w_logprobs
(
def
generate_encoder_decoder_w_logprobs
(
self
,
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
'''
'''
...
@@ -653,7 +659,7 @@ class VllmRunner:
...
@@ -653,7 +659,7 @@ class VllmRunner:
def
generate_encoder_decoder_greedy_logprobs
(
def
generate_encoder_decoder_greedy_logprobs
(
self
,
self
,
encoder_decoder_prompts
:
Tuple
[
Lis
t
[
str
]
,
List
[
str
]],
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPromp
t
[
str
,
str
]],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
...
...
tests/distributed/test_basic_distributed_correctness_enc_dec.py
View file @
7eb4a51c
...
@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
...
@@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
import
pytest
import
pytest
from
tests.models.utils
import
DecoderPromptType
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
from
..models.utils
import
check_logprobs_close
from
..utils
import
fork_new_process_for_each_test
from
..utils
import
fork_new_process_for_each_test
...
...
tests/entrypoints/openai/test_encoder_decoder.py
0 → 100644
View file @
7eb4a51c
import
openai
import
pytest
from
...utils
import
RemoteOpenAIServer
MODEL_NAME
=
"facebook/bart-base"
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
args
=
[
"--dtype"
,
"bfloat16"
,
"--enforce-eager"
,
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest
.
fixture
(
scope
=
"module"
)
def
client
(
server
):
return
server
.
get_async_client
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
async
def
test_single_completion
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello, my name is"
,
max_tokens
=
5
,
temperature
=
0.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
choice
=
completion
.
choices
[
0
]
assert
len
(
choice
.
text
)
>=
5
assert
choice
.
finish_reason
==
"length"
assert
completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
5
,
prompt_tokens
=
2
,
total_tokens
=
7
)
# test using token IDs
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
[
0
,
0
,
0
,
0
,
0
],
max_tokens
=
5
,
temperature
=
0.0
,
)
assert
len
(
completion
.
choices
[
0
].
text
)
>=
1
tests/models/test_bart.py
View file @
7eb4a51c
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
Run `pytest tests/models/test_bart.py`.
Run `pytest tests/models/test_bart.py`.
"""
"""
from
typing
import
List
,
Optional
,
Tuple
from
vllm.utils
import
is_cpu
from
vllm.utils
import
is_cpu
if
not
is_cpu
():
if
not
is_cpu
():
...
@@ -11,22 +13,31 @@ if not is_cpu():
...
@@ -11,22 +13,31 @@ if not is_cpu():
import
pytest
import
pytest
from
tests.models.utils
import
DecoderPromptType
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
DecoderPromptType
from
.utils
import
check_logprobs_close
from
.utils
import
check_logprobs_close
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
MODELS
=
[
"facebook/bart-base"
,
"facebook/bart-large-cnn"
]
DECODER_PROMPT_TYPES
=
([
def
vllm_to_hf_output
(
DecoderPromptType
.
CUSTOM
,
DecoderPromptType
.
EMPTY_STR
,
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
DecoderPromptType
.
NONE
decoder_prompt_type
:
DecoderPromptType
,
])
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
hf_output_str
=
output_str
+
"</s>"
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
:
hf_output_str
=
"<s>"
+
hf_output_str
return
output_ids
,
hf_output_str
,
out_logprobs
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
DECODER_PROMPT_TYPES
)
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
)
)
def
test_models
(
def
test_models
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -146,8 +157,13 @@ if not is_cpu():
...
@@ -146,8 +157,13 @@ if not is_cpu():
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
else
0
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
check_logprobs_close
(
outputs_1_lst
=
vllm_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
for
vllm_output
in
vllm_outputs
],
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
)
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
)
tests/models/utils.py
View file @
7eb4a51c
import
warnings
import
warnings
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
...
@@ -136,13 +135,3 @@ def check_logprobs_close(
...
@@ -136,13 +135,3 @@ def check_logprobs_close(
warnings
.
simplefilter
(
"always"
)
warnings
.
simplefilter
(
"always"
)
warnings
.
warn
(
fail_msg
,
stacklevel
=
2
)
warnings
.
warn
(
fail_msg
,
stacklevel
=
2
)
class
DecoderPromptType
(
Enum
):
'''
For encoder/decoder models only -
'''
CUSTOM
=
1
NONE
=
2
EMPTY_STR
=
3
tests/test_inputs.py
View file @
7eb4a51c
...
@@ -2,7 +2,7 @@ from typing import List
...
@@ -2,7 +2,7 @@ from typing import List
import
pytest
import
pytest
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
.parse
import
parse_and_batch_prompt
STRING_INPUTS
=
[
STRING_INPUTS
=
[
''
,
''
,
...
...
vllm/config.py
View file @
7eb4a51c
...
@@ -464,6 +464,16 @@ class ModelConfig:
...
@@ -464,6 +464,16 @@ class ModelConfig:
if
t
!=
"attention"
if
t
!=
"attention"
])
])
@
property
def
is_encoder_decoder_model
(
self
)
->
bool
:
"""Extract the HF encoder/decoder model flag."""
return
getattr
(
self
.
hf_config
,
"is_encoder_decoder"
,
False
)
@
property
def
is_embedding_model
(
self
)
->
bool
:
"""Extract the embedding model flag."""
return
self
.
embedding_mode
class
CacheConfig
:
class
CacheConfig
:
"""Configuration for the KV cache.
"""Configuration for the KV cache.
...
...
vllm/engine/async_llm_engine.py
View file @
7eb4a51c
...
@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
...
@@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
...
@@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
(
DecoderPromptComponents
,
LLMEngine
,
PromptComponents
)
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.inputs
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
...
@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
async
def
process_model_inputs
_async
(
async
def
_tokenize_prompt
_async
(
self
,
self
,
prompt
:
str
,
request_id
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
List
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
return
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
async
def
_extract_prompt_components_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
"""Async version of :meth:`_extract_prompt_components`."""
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
else
:
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
parsed_prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
else
:
assert_never
(
inputs
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
async
def
_process_encoder_decoder_prompt_async
(
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
encoder_task
=
self
.
_extract_prompt_components_async
(
inputs
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
else
:
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_input
,
request_id
=
request_id
,
)
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
async
def
_process_decoder_only_prompt_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
)
->
LLMInputs
:
if
isinstance
(
inputs
,
str
):
"""Async version of :meth:`_process_decoder_only_prompt`."""
inputs
=
{
"prompt"
:
inputs
}
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
if
"prompt_token_ids"
not
in
inputs
:
return
self
.
_build_decoder_only_llm_inputs
(
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
prompt_comps
,
"skip_tokenizer_init is True"
)
prompt_adapter_request
=
prompt_adapter_request
,
)
prompt_token_ids
=
await
tokenizer
.
encode_async
(
async
def
process_model_inputs_async
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]:
"""Async version of :meth:`process_model_inputs`."""
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs
=
await
self
.
_process_encoder_decoder_prompt_async
(
inputs
,
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
inputs
[
"prompt"
],
)
lora_request
=
lora_request
)
else
:
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
if
prompt_adapter_request
:
"to decoder-only models"
)
prompt_token_ids
=
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
\
prompt_token_ids
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
# Decoder-only operation
prompt
=
inputs
.
get
(
"prompt"
),
model_inputs
=
await
self
.
_process_decoder_only_prompt_async
(
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
input_processor
(
llm
_inputs
)
return
self
.
input_processor
(
model
_inputs
)
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
...
@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
)
->
None
:
"""Async version of :meth:`add_request`."""
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
...
@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
processed_inputs
=
await
self
.
process_model_inputs_async
(
processed_inputs
=
await
self
.
process_model_inputs_async
(
inputs
,
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
...
...
vllm/engine/llm_engine.py
View file @
7eb4a51c
...
@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
...
@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
Set
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
...
@@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
LLMInputs
,
PromptInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
LLMInputs
,
get_prompt_type
)
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer
,
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
AnyTokenizer
,
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
(
Counter
,
is_embedding_model_config
,
from
vllm.utils
import
Counter
is_encoder_decoder_model_config
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
...
@@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
EmbeddingRequestOutput
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
EmbeddingRequestOutput
)
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
Optional
[
MultiModalDataDict
]]
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
Optional
[
MultiModalDataDict
]]
class
LLMEngine
:
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
"""An LLM engine that receives requests and generates texts.
...
@@ -524,7 +532,7 @@ class LLMEngine:
...
@@ -524,7 +532,7 @@ class LLMEngine:
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
_get_decoder_start_token_id
(
self
,
)
->
Optional
[
int
]:
def
_get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
'''
'''
Obtain the decoder start token id employed by an encoder/decoder
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model. Returns None for non-encoder/decoder models or if the
...
@@ -553,7 +561,7 @@ class LLMEngine:
...
@@ -553,7 +561,7 @@ class LLMEngine:
def
_add_processed_request
(
def
_add_processed_request
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
processed_inputs
:
LLMInputs
,
processed_inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
...
@@ -613,11 +621,11 @@ class LLMEngine:
...
@@ -613,11 +621,11 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
_LLMInputComponentsType
=
Tuple
[
str
,
List
[
int
]
,
]
_LLMInputComponentsType
=
Tuple
[
str
,
List
[
int
]]
def
_prepare_decoder_input_ids_for_generation
(
def
_prepare_decoder_input_ids_for_generation
(
self
,
self
,
decoder_input_ids
:
Optional
[
List
[
int
]]
=
None
,
decoder_input_ids
:
Optional
[
List
[
int
]],
)
->
List
[
int
]:
)
->
List
[
int
]:
"""
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Prepares `decoder_input_ids` for generation with encoder-decoder models.
...
@@ -639,14 +647,13 @@ class LLMEngine:
...
@@ -639,14 +647,13 @@ class LLMEngine:
* Processed token list
* Processed token list
"""
"""
decoder_start_token_id
:
Optional
[
int
]
=
(
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
()
self
.
_get_decoder_start_token_id
())
assert
decoder_start_token_id
is
not
None
assert
decoder_start_token_id
is
not
None
if
decoder_input_ids
is
None
:
if
decoder_input_ids
is
None
:
# no decoder prompt input ->
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
# use decoder_start_token_id as decoder_input_ids
(
decoder_input_ids
)
=
self
.
_get_default_enc_dec_decoder_prompt
()
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
if
(
len
(
decoder_input_ids
)
==
0
if
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
...
@@ -657,12 +664,11 @@ class LLMEngine:
...
@@ -657,12 +664,11 @@ class LLMEngine:
def
_tokenize_prompt
(
def
_tokenize_prompt
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
request_id
:
str
,
lora_request
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
,
)
->
List
[
int
]:
)
->
List
[
int
]:
'''
'''
Wrapper around application of the model's
Wrapper around application of the model's tokenizer.
tokenizer.
Arguments:
Arguments:
...
@@ -678,87 +684,72 @@ class LLMEngine:
...
@@ -678,87 +684,72 @@ class LLMEngine:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
"skip_tokenizer_init is True"
)
prompt_token_ids
=
tokenizer
.
encode
(
request_id
=
request_id
,
return
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
return
prompt_token_ids
def
_extract_prompt_components
(
def
_extract_single_prompt_for_enc_dec_input
(
self
,
self
,
inputs
:
Optional
[
PromptInputs
],
inputs
:
SingletonPromptInputs
,
request_id
:
Optional
[
str
]
=
None
,
request_id
:
str
,
ptype
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
is_encoder_prompt
:
bool
=
False
,
)
->
PromptComponents
:
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
'''
'''
Only for encoder/decoder models:
Extract the components of any single encoder or decoder input prompt.
Extract prompt & prompt_token_ids from any single
encoder or decoder input prompt. For encoder input prompts
in particular, also extract multi-modal data.
This function handles the following scenarios:
1. The user supplied a singleton encoder prompt
& the prompt/prompt-token-ids must be extracted.
2. The user supplied an explicit encoder/decoder
prompt & the prompt/prompt-token-ids must be
extracted from either the encoder and decoder prompts.
For decoder prompts in particular (scenario 2), special
processing is applied to the returned decoder token ids.
Arguments:
Arguments:
* request_id
* request_id
* ptype: str representation of the input prompt type.
If `ptype` is `None`, assume that the prompt
type is unknown and must be inferred. This is the
case for ExplicitEncoderDecoder sub-prompts.
* inputs: single encoder or decoder input prompt
* inputs: single encoder or decoder input prompt
* is_encoder_prompt: True if encoder input prompt.
* lora_request: this is only valid for decoder prompts
If False, decoder prompt tokens
are preprocessed.
Returns:
Returns:
* prompt
* prompt
* prompt_token_ids
* prompt_token_ids
* multi_modal_data
'''
'''
prompt_token_ids
=
None
ptype
=
(
get_prompt_type
(
inputs
)
if
ptype
is
None
else
ptype
)
if
inputs
is
None
:
if
isinstance
(
inputs
,
str
):
prompt
=
None
elif
ptype
==
'str'
:
prompt
=
inputs
prompt
=
inputs
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
)
elif
ptype
==
'TokensPrompt'
:
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt
=
None
prompt_token_ids
=
inputs
[
'
prompt_token_ids
'
]
prompt_token_ids
=
inputs
[
"
prompt_token_ids
"
]
else
:
else
:
prompt
=
inputs
[
'prompt'
]
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
parsed_
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
)
if
not
is_encoder_prompt
:
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
# Apply special pre-processing to
else
:
# decoder prompts
assert_never
(
inputs
)
prompt_token_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
prompt_token_ids
,
))
assert
prompt_token_ids
is
not
None
return
prompt
,
prompt_token_ids
,
multi_modal_data
return
(
def
_apply_prompt_adapter
(
prompt
,
self
,
prompt_token_ids
,
prompt_token_ids
:
List
[
int
],
)
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
List
[
int
]:
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
def
_get_default_enc_dec_decoder_prompt
(
self
,
)
->
List
[
int
]:
return
prompt_token_ids
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
List
[
int
]:
'''
'''
Specifically for encoder/decoder models:
Specifically for encoder/decoder models:
generate a default decoder prompt for when
generate a default decoder prompt for when
...
@@ -792,18 +783,39 @@ class LLMEngine:
...
@@ -792,18 +783,39 @@ class LLMEngine:
bos_token_id
=
self
.
_get_bos_token_id
()
bos_token_id
=
self
.
_get_bos_token_id
()
assert
bos_token_id
is
not
None
assert
bos_token_id
is
not
None
prompt_token_ids
:
List
[
int
]
=
[
bos_token_id
]
return
[
bos_token_id
]
return
prompt_token_ids
def
_build_enc_dec_llm_inputs
(
self
,
encoder_comps
:
PromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
)
->
EncoderDecoderLLMInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
if
encoder_mm_data
is
not
None
or
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modal encoder-decoder models are "
"not supported yet"
)
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
))
return
EncoderDecoderLLMInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
)
def
_process_encoder_decoder_prompt
(
def
_process_encoder_decoder_prompt
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
request_id
:
Optional
[
str
]
=
None
,
request_id
:
str
,
)
->
LLMInputs
:
)
->
EncoderDecoder
LLMInputs
:
'''
'''
For encoder/decoder models only:
For encoder/decoder models only:
Process an input prompt
Process an input prompt
into an
into an `
LLMInputs` instance.
:class:`EncoderDecoder
LLMInputs` instance.
There are two types of input prompts:
There are two types of input prompts:
singleton prompts which carry only the
singleton prompts which carry only the
...
@@ -830,136 +842,103 @@ class LLMEngine:
...
@@ -830,136 +842,103 @@ class LLMEngine:
Returns:
Returns:
*
`
LLMInputs` instance
*
:class:`EncoderDecoder
LLMInputs` instance
'''
'''
ptype
=
get_prompt_type
(
inputs
)
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
# Obtain encoder and decoder prompt tokens. Note
# that, no matter what, the decoder
if
is_explicit_encoder_decoder_prompt
(
inputs
):
# prompt type is unknown.
encoder_comps
=
self
.
_extract_prompt_components
(
if
ptype
==
"ExplicitEncoderDecoder"
:
inputs
[
"encoder_prompt"
],
# If input is explicit encoder/decoder prompt,
# then it remains to be determined what type
# of encoder prompt we have
extracted_encoder_prompt
=
inputs
.
get
(
'encoder_prompt'
)
encoder_ptype
=
None
# Extract decoder prompt from explicit
# encoder/decoder prompt
extracted_decoder_prompt
=
inputs
.
get
(
'decoder_prompt'
)
else
:
# If input is singleton encoder prompt, then
# we know the encoder prompt type
extracted_encoder_prompt
=
inputs
encoder_ptype
=
ptype
# Decoder prompt is always unknown if
# encoder/decoder prompt is not explicit
extracted_decoder_prompt
=
None
# Invoke helper function to obtain encoder
# prompt and prompt token ids, either from
# singleton encoder prompt or from the
# encoder sub-prompt of an explicit
# encoder/decode scenario 2), special
# processing is applied to the returned decoder token ids
(
encoder_prompt
,
encoder_prompt_token_ids
,
)
=
self
.
_extract_single_prompt_for_enc_dec_input
(
extracted_encoder_prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
ptype
=
encoder_ptype
,
is_encoder_prompt
=
True
,
)
)
# Invoke helper method to obtain
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
# decoder prompt and prompt token ids.
decoder_comps
=
None
,
None
,
None
#
else
:
# The helper method will detect the decoder
decoder_comps
=
self
.
_extract_prompt_components
(
# prompt type.
decoder_input
,
#
# Helper method will also apply special
# preprocessing unique to decoder prompts.
(
decoder_prompt
,
decoder_prompt_token_ids
,
)
=
self
.
_extract_single_prompt_for_enc_dec_input
(
extracted_decoder_prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
ptype
=
None
,
is_encoder_prompt
=
False
,
)
)
else
:
return
LLMInputs
(
encoder_comps
=
self
.
_extract_prompt_components
(
prompt_token_ids
=
decoder_prompt_token_ids
,
inputs
,
prompt
=
decoder_prompt
,
request_id
=
request_id
,
encoder_prompt_token_ids
=
encoder_prompt_token_ids
,
encoder_prompt
=
encoder_prompt
,
)
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
def
_build_decoder_only_llm_inputs
(
self
,
prompt_comps
:
PromptComponents
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
LLMInputs
:
prompt
,
prompt_token_ids
,
multi_modal_data
=
prompt_comps
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
def
_process_decoder_only_prompt
(
def
_process_decoder_only_prompt
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
)
->
LLMInputs
:
'''
'''
For decoder-only models:
For decoder-only models:
Process an input prompt
Process an input prompt into an :class:`LLMInputs` instance.
into an `LLMInputs` instance.
Arguments:
Arguments:
* inputs: input prompt
* inputs: input prompt
* lora_request
* request_id
* request_id
* lora_request
* prompt_adapter_request
* prompt_adapter_request
Returns:
Returns:
* `LLMInputs` instance
*
:class:
`LLMInputs` instance
'''
'''
if
isinstance
(
inputs
,
str
):
prompt_comps
=
self
.
_extract_prompt_components
(
inputs
=
{
"prompt"
:
inputs
}
inputs
,
prompt
=
inputs
.
get
(
"prompt"
)
if
"prompt_token_ids"
not
in
inputs
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
)
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
if
prompt_adapter_request
:
return
self
.
_build_decoder_only_llm_inputs
(
prompt_token_ids
=
(
prompt_comps
,
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
prompt_adapter_request
=
prompt_adapter_request
,
+
prompt_token_ids
)
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
def
process_model_inputs
(
def
process_model_inputs
(
self
,
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]
:
if
self
.
is_encoder_decoder_model
():
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder
model_inputs
=
self
.
_process_encoder_decoder_prompt
(
model_inputs
=
self
.
_process_encoder_decoder_prompt
(
inputs
,
inputs
,
request_id
=
request_id
,
request_id
=
request_id
,
)
)
else
:
else
:
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
# Decoder-only operation
# Decoder-only operation
model_inputs
=
self
.
_process_decoder_only_prompt
(
model_inputs
=
self
.
_process_decoder_only_prompt
(
inputs
,
inputs
,
...
@@ -1029,10 +1008,11 @@ class LLMEngine:
...
@@ -1029,10 +1008,11 @@ class LLMEngine:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
processed_inputs
=
self
.
process_model_inputs
(
processed_inputs
=
self
.
process_model_inputs
(
inputs
,
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
...
@@ -1597,7 +1577,7 @@ class LLMEngine:
...
@@ -1597,7 +1577,7 @@ class LLMEngine:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_E2E
,
e2e_time
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_E2E
,
e2e_time
)
def
is_encoder_decoder_model
(
self
):
def
is_encoder_decoder_model
(
self
):
return
is_encoder_decoder_model
_config
(
self
.
model_config
)
return
self
.
model_config
.
is_encoder_decoder_model
def
is_embedding_model
(
self
):
def
is_embedding_model
(
self
):
return
is_embedding_model
_config
(
self
.
model_config
)
return
self
.
model_config
.
is_embedding_model
vllm/entrypoints/chat_utils.py
View file @
7eb4a51c
...
@@ -2,8 +2,7 @@ import codecs
...
@@ -2,8 +2,7 @@ import codecs
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
from
typing
import
Any
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
cast
,
final
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
...
@@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam
]
CustomChatCompletionMessageParam
]
@
final
# So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class
ConversationMessage
(
TypedDict
):
class
ConversationMessage
(
TypedDict
):
role
:
str
role
:
str
content
:
str
content
:
str
...
...
vllm/entrypoints/llm.py
View file @
7eb4a51c
...
@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...
@@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.inputs
import
(
PromptInputs
,
TextPrompt
,
TokensPrompt
,
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
parse_and_batch_prompt
)
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
...
...
vllm/entrypoints/openai/logits_processors.py
View file @
7eb4a51c
...
@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
...
@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return
AllowedTokenIdsLogitsProcessor
(
allowed_token_ids
)
return
AllowedTokenIdsLogitsProcessor
(
allowed_token_ids
)
def
logit_bias_logits_processor
(
logit_bias
:
Dict
[
str
,
def
logit_bias_logits_processor
(
float
],
token_ids
:
List
[
int
],
logit_bias
:
Dict
[
int
,
float
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
logits
[
token_id
]
+=
bias
return
logits
return
logits
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
7eb4a51c
...
@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
...
@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeCompletionRequest
,
TokenizeCompletionRequest
,
TokenizeRequest
)
TokenizeRequest
)
# yapf: enable
# yapf: enable
from
vllm.inputs
import
parse_and_batch_prompt
from
vllm.inputs
.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
from
vllm.model_executor.guided_decoding
import
(
...
...
vllm/inputs/__init__.py
View file @
7eb4a51c
from
.data
import
(
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
ParsedText
,
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
ParsedToken
s
,
PromptInputs
,
SingletonPromptInputs
,
LLMInput
s
,
PromptInputs
,
SingletonPromptInputs
,
TextPrompt
,
TextPrompt
,
TokensPrompt
,
get
_prompt
_type
,
TokensPrompt
,
build_explicit_enc_dec
_prompt
,
is_valid_encoder_decoder_llm_inputs
,
parse_and_batch
_prompt
)
to_enc_dec_tuple_list
,
zip_enc_dec
_prompt
s
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
...
@@ -14,18 +14,17 @@ See also:
...
@@ -14,18 +14,17 @@ See also:
"""
"""
__all__
=
[
__all__
=
[
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"TokensPrompt"
,
"PromptInputs"
,
"PromptInputs"
,
"SingletonPromptInputs"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"LLMInputs"
,
"EncoderDecoderLLMInputs"
,
"build_explicit_enc_dec_prompt"
,
"to_enc_dec_tuple_list"
,
"zip_enc_dec_prompts"
,
"INPUT_REGISTRY"
,
"INPUT_REGISTRY"
,
"InputContext"
,
"InputContext"
,
"InputRegistry"
,
"InputRegistry"
,
"get_prompt_type"
,
"is_valid_encoder_decoder_llm_inputs"
,
"ExplicitEncoderDecoderPrompt"
,
"SingletonPromptInputs"
,
]
]
vllm/inputs/data.py
View file @
7eb4a51c
from
typing
import
(
TYPE_CHECKING
,
List
,
Literal
,
Optional
,
Sequenc
e
,
from
typing
import
(
TYPE_CHECKING
,
Generic
,
Iterable
,
List
,
Optional
,
Tupl
e
,
TypedDict
,
Union
,
cast
,
overload
)
Union
)
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
class
ParsedText
(
TypedDict
):
content
:
str
is_tokens
:
Literal
[
False
]
class
ParsedTokens
(
TypedDict
):
content
:
List
[
int
]
is_tokens
:
Literal
[
True
]
# https://github.com/vllm-project/vllm/pull/4028
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
]])
->
Sequence
[
ParsedText
]:
...
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
List
[
int
],
List
[
List
[
int
]]])
->
Sequence
[
ParsedTokens
]:
...
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]]],
)
->
Union
[
Sequence
[
ParsedText
],
Sequence
[
ParsedTokens
]]:
if
isinstance
(
prompt
,
str
):
# case 1: a string
return
[
ParsedText
(
content
=
prompt
,
is_tokens
=
False
)]
if
isinstance
(
prompt
,
list
):
if
len
(
prompt
)
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
isinstance
(
prompt
[
0
],
str
):
# case 2: array of strings
return
[
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
cast
(
List
[
str
],
prompt
)
]
if
isinstance
(
prompt
[
0
],
int
):
# case 3: array of tokens
elem
=
cast
(
List
[
int
],
prompt
)
return
[
ParsedTokens
(
content
=
elem
,
is_tokens
=
True
)]
if
isinstance
(
prompt
[
0
],
list
):
if
len
(
prompt
[
0
])
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
isinstance
(
prompt
[
0
][
0
],
int
):
# case 4: array of token arrays
return
[
ParsedTokens
(
content
=
elem
,
is_tokens
=
True
)
for
elem
in
cast
(
List
[
List
[
int
]],
prompt
)
]
raise
ValueError
(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
class
TextPrompt
(
TypedDict
):
class
TextPrompt
(
TypedDict
):
"""Schema for a text prompt."""
"""Schema for a text prompt."""
...
@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
...
@@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
the user desires to express both the encoder & decoder
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
prompts explicitly, i.e.
:class:`
ExplicitEncoderDecoderPrompt
`
A prompt of type SingletonPromptInputs may be employed
A prompt of type
:class:`
SingletonPromptInputs
`
may be employed
as (1) input to a decoder-only model, (2) input to
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
more than one prompt, i.e.
:class:`
ExplicitEncoderDecoderPrompt
`
"""
"""
_T1_co
=
TypeVar
(
"_T1_co"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
,
covariant
=
True
)
_T2_co
=
TypeVar
(
"_T2_co"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
,
covariant
=
True
)
class
ExplicitEncoderDecoderPrompt
(
TypedDict
):
# TODO: Make fields ReadOnly once mypy supports it
class
ExplicitEncoderDecoderPrompt
(
TypedDict
,
Generic
[
_T1_co
,
_T2_co
]):
"""Represents an encoder/decoder model input prompt,
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
comprising an explicit encoder prompt and a
decoder prompt.
decoder prompt.
The encoder and decoder prompts, respectively,
The encoder and decoder prompts, respectively,
may formatted according to any of the
may formatted according to any of the
SingletonPromptInputs schemas, and are not
:class:`
SingletonPromptInputs
`
schemas, and are not
required to have the same schema.
required to have the same schema.
Only the encoder prompt may have multi-modal data.
Only the encoder prompt may have multi-modal data.
Note that an ExplicitEncoderDecoderPrompt may not
Note that an
:class:`
ExplicitEncoderDecoderPrompt
`
may not
be used as an input to a decoder-only model,
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure
may not
themselves
fields of this data structure themselves
must be
must be
SingletonPromptInputs instances.
:class:`
SingletonPromptInputs
`
instances.
"""
"""
encoder_prompt
:
SingletonPromptInputs
encoder_prompt
:
_T1_co
decoder_prompt
:
SingletonPromptInputs
decoder_prompt
:
Optional
[
_T2_co
]
PromptInputs
=
Union
[
SingletonPromptInputs
,
ExplicitEncoderDecoderPrompt
]
PromptInputs
=
Union
[
SingletonPromptInputs
,
ExplicitEncoderDecoderPrompt
]
...
@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
...
@@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
"""
"""
def
_has_required_keys
(
d
:
dict
,
required_keys
:
set
,
)
->
bool
:
return
required_keys
.
issubset
(
d
.
keys
())
def
get_prompt_type
(
prompt
:
Optional
[
PromptInputs
])
->
Optional
[
str
]:
"""
Get the type-name of the prompt argument instance, given that
isinstance() cannot apply to TypedDict subclasses directly.
If the prompt is None, return 'None' as the type name.
Arguments:
* prompt: LLM input prompt or None
Returns:
* String representation of prompt type
"""
if
prompt
is
None
:
return
'None'
required_keys_dict
=
{
'TextPrompt'
:
{
'prompt'
},
'TokensPrompt'
:
{
'prompt_token_ids'
},
'ExplicitEncoderDecoder'
:
{
'encoder_prompt'
,
'decoder_prompt'
},
}
if
isinstance
(
prompt
,
dict
):
for
(
ptype
,
required_keys
)
in
required_keys_dict
.
items
():
# Ignore type checking in the conditional below because type
# checker does not understand that is_dict(prompt) narrows
# down the possible types
if
_has_required_keys
(
prompt
,
# type: ignore
required_keys
):
return
ptype
raise
ValueError
(
f
"Invalid prompt
{
prompt
}
, valid types are "
"required_keys_dict={required_keys_dict}"
)
if
isinstance
(
prompt
,
str
):
return
"str"
raise
ValueError
(
f
"Invalid prompt
{
prompt
}
"
)
class
LLMInputs
(
TypedDict
):
class
LLMInputs
(
TypedDict
):
"""
"""
The inputs in :class:`~vllm.LLMEngine` before they are
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
passed to the model executor.
This specifies the data required for decoder-only models.
"""
"""
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt."""
"""The token IDs of the prompt."""
...
@@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
...
@@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
The original prompt text corresponding to the token IDs, if available.
"""
"""
encoder_prompt_token_ids
:
NotRequired
[
List
[
int
]]
multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class
EncoderDecoderLLMInputs
(
LLMInputs
):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids
:
List
[
int
]
"""The token IDs of the encoder prompt."""
"""The token IDs of the encoder prompt."""
encoder_prompt
:
NotRequired
[
Optional
[
str
]]
encoder_prompt
:
NotRequired
[
Optional
[
str
]]
...
@@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
...
@@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
available.
available.
"""
"""
multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
_T1
=
TypeVar
(
"_T1"
,
Optional multi-modal data to pass to the model,
bound
=
SingletonPromptInputs
,
if the model supports it.
default
=
SingletonPromptInputs
)
"""
_T2
=
TypeVar
(
"_T2"
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
def
build_explicit_enc_dec_prompt
(
encoder_prompt
:
_T1
,
decoder_prompt
:
Optional
[
_T2
],
)
->
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]:
return
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
,
decoder_prompt
=
decoder_prompt
)
def
is_valid_encoder_decoder_llm_inputs
(
inputs
:
LLMInputs
)
->
bool
:
def
zip_enc_dec_prompts
(
enc_prompts
:
Iterable
[
_T1
],
dec_prompts
:
Iterable
[
Optional
[
_T2
]],
)
->
List
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
"""
"""
Return True if the LLMInputs instance has the correct configuration
Zip encoder and decoder prompts together into a list of
for e
ncoder
/d
ecoder.
:class:`ExplicitE
ncoder
D
ecoder
Prompt` instances
.
"""
"""
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
)
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
]
# True if encoder prompt token ids field exists &
def
to_enc_dec_tuple_list
(
# is not None
enc_dec_prompts
:
Iterable
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]],
return
(
'encoder_prompt_token_ids'
in
inputs
)
->
List
[
Tuple
[
_T1
,
Optional
[
_T2
]]]:
and
inputs
[
'encoder_prompt_token_ids'
]
is
not
None
)
return
[(
enc_dec_prompt
[
"encoder_prompt"
],
enc_dec_prompt
[
"decoder_prompt"
])
for
enc_dec_prompt
in
enc_dec_prompts
]
vllm/inputs/parse.py
0 → 100644
View file @
7eb4a51c
from
typing
import
List
,
Literal
,
Sequence
,
TypedDict
,
Union
,
overload
from
typing_extensions
import
TypeIs
from
vllm.utils
import
is_list_of
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
PromptInputs
)
class
ParsedText
(
TypedDict
):
content
:
str
is_tokens
:
Literal
[
False
]
class
ParsedTokens
(
TypedDict
):
content
:
List
[
int
]
is_tokens
:
Literal
[
True
]
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
]])
->
Sequence
[
ParsedText
]:
...
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
List
[
int
],
List
[
List
[
int
]]])
->
Sequence
[
ParsedTokens
]:
...
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
List
[
str
],
List
[
int
],
List
[
List
[
int
]]],
)
->
Union
[
Sequence
[
ParsedText
],
Sequence
[
ParsedTokens
]]:
if
isinstance
(
prompt
,
str
):
# case 1: a string
return
[
ParsedText
(
content
=
prompt
,
is_tokens
=
False
)]
if
isinstance
(
prompt
,
list
):
if
len
(
prompt
)
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
is_list_of
(
prompt
,
str
):
# case 2: array of strings
return
[
ParsedText
(
content
=
elem
,
is_tokens
=
False
)
for
elem
in
prompt
]
if
is_list_of
(
prompt
,
int
):
# case 3: array of tokens
return
[
ParsedTokens
(
content
=
prompt
,
is_tokens
=
True
)]
if
is_list_of
(
prompt
,
list
):
if
len
(
prompt
[
0
])
==
0
:
raise
ValueError
(
"please provide at least one prompt"
)
if
is_list_of
(
prompt
[
0
],
int
):
# case 4: array of token arrays
return
[
ParsedTokens
(
content
=
elem
,
is_tokens
=
True
)
for
elem
in
prompt
]
raise
ValueError
(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
def
is_explicit_encoder_decoder_prompt
(
inputs
:
PromptInputs
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
inputs
,
dict
)
and
"encoder_prompt"
in
inputs
def
is_valid_encoder_decoder_llm_inputs
(
inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
],
)
->
TypeIs
[
EncoderDecoderLLMInputs
]:
return
"encoder_prompt_token_ids"
in
inputs
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment