Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e254497b
"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "5819ca8944af4f7dcbac3c6b73179f760e05910d"
Unverified
Commit
e254497b
authored
May 11, 2024
by
Chang Su
Committed by
GitHub
May 11, 2024
Browse files
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
parent
4e121310
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
599 additions
and
90 deletions
+599
-90
examples/offline_inference_embedding.py
examples/offline_inference_embedding.py
+17
-0
examples/openai_embedding_client.py
examples/openai_embedding_client.py
+23
-0
requirements-dev.txt
requirements-dev.txt
+6
-3
tests/conftest.py
tests/conftest.py
+30
-8
tests/engine/output_processor/test_multi_step.py
tests/engine/output_processor/test_multi_step.py
+6
-6
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+1
-0
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+95
-1
tests/models/test_embedding.py
tests/models/test_embedding.py
+44
-0
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+3
-3
tests/samplers/test_seeded_generate.py
tests/samplers/test_seeded_generate.py
+1
-1
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+3
-3
tests/test_sequence.py
tests/test_sequence.py
+6
-6
vllm/__init__.py
vllm/__init__.py
+6
-1
vllm/config.py
vllm/config.py
+15
-0
vllm/core/embedding_model_block_manager.py
vllm/core/embedding_model_block_manager.py
+84
-0
vllm/core/interfaces.py
vllm/core/interfaces.py
+5
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+8
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+130
-28
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+115
-28
No files found.
examples/offline_inference_embedding.py
0 → 100644
View file @
e254497b
from
vllm
import
LLM
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create an LLM.
model
=
LLM
(
model
=
"intfloat/e5-mistral-7b-instruct"
,
enforce_eager
=
True
)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs
=
model
.
encode
(
prompts
)
# Print the outputs.
for
output
in
outputs
:
print
(
output
.
outputs
.
embedding
)
# list of 4096 floats
examples/openai_embedding_client.py
0 → 100644
View file @
e254497b
from
openai
import
OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key
=
"EMPTY"
openai_api_base
=
"http://localhost:8000/v1"
client
=
OpenAI
(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key
=
openai_api_key
,
base_url
=
openai_api_base
,
)
models
=
client
.
models
.
list
()
model
=
models
.
data
[
0
].
id
responses
=
client
.
embeddings
.
create
(
input
=
[
"Hello my name is"
,
"The best thing about vLLM is that it supports many different models"
],
model
=
model
)
for
data
in
responses
.
data
:
print
(
data
.
embedding
)
# list of float of len 4096
requirements-dev.txt
View file @
e254497b
...
@@ -19,12 +19,15 @@ pytest-forked
...
@@ -19,12 +19,15 @@ pytest-forked
pytest-asyncio
pytest-asyncio
pytest-rerunfailures
pytest-rerunfailures
pytest-shard
pytest-shard
httpx
# testing utils
awscli
einops # required for MPT
einops # required for MPT
httpx
peft
requests
requests
ray
ray
peft
sentence-transformers # required for embedding
awscli
# Benchmarking
# Benchmarking
aiohttp
aiohttp
...
...
tests/conftest.py
View file @
e254497b
...
@@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
...
@@ -133,6 +133,10 @@ _VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf"
:
LlavaForConditionalGeneration
,
"llava-hf/llava-1.5-7b-hf"
:
LlavaForConditionalGeneration
,
}
}
_EMBEDDING_MODELS
=
[
"intfloat/e5-mistral-7b-instruct"
,
]
class
HfRunner
:
class
HfRunner
:
...
@@ -145,14 +149,7 @@ class HfRunner:
...
@@ -145,14 +149,7 @@ class HfRunner:
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
self
.
model_name
=
model_name
self
.
model_name
=
model_name
if
model_name
not
in
_VISION_LANGUAGE_MODELS
:
if
model_name
in
_VISION_LANGUAGE_MODELS
:
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
).
cuda
()
self
.
processor
=
None
else
:
self
.
model
=
_VISION_LANGUAGE_MODELS
[
model_name
].
from_pretrained
(
self
.
model
=
_VISION_LANGUAGE_MODELS
[
model_name
].
from_pretrained
(
model_name
,
model_name
,
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
...
@@ -162,6 +159,20 @@ class HfRunner:
...
@@ -162,6 +159,20 @@ class HfRunner:
model_name
,
model_name
,
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
)
)
elif
model_name
in
_EMBEDDING_MODELS
:
# Lazy init required for AMD CI
from
sentence_transformers
import
SentenceTransformer
self
.
model
=
SentenceTransformer
(
model_name
,
device
=
"cpu"
,
).
to
(
dtype
=
torch_dtype
).
cuda
()
else
:
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
).
cuda
()
self
.
processor
=
None
if
tokenizer_name
is
None
:
if
tokenizer_name
is
None
:
tokenizer_name
=
model_name
tokenizer_name
=
model_name
self
.
tokenizer
=
get_tokenizer
(
tokenizer_name
,
trust_remote_code
=
True
)
self
.
tokenizer
=
get_tokenizer
(
tokenizer_name
,
trust_remote_code
=
True
)
...
@@ -334,6 +345,9 @@ class HfRunner:
...
@@ -334,6 +345,9 @@ class HfRunner:
return
[(
output_ids
,
output_str
,
output_logprobs
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
def
encode
(
self
,
prompts
:
List
[
str
])
->
List
[
List
[
torch
.
Tensor
]]:
return
self
.
model
.
encode
(
prompts
)
def
__del__
(
self
):
def
__del__
(
self
):
del
self
.
model
del
self
.
model
cleanup
()
cleanup
()
...
@@ -459,6 +473,14 @@ class VllmRunner:
...
@@ -459,6 +473,14 @@ class VllmRunner:
outputs
=
self
.
generate
(
prompts
,
beam_search_params
)
outputs
=
self
.
generate
(
prompts
,
beam_search_params
)
return
outputs
return
outputs
def
encode
(
self
,
prompts
:
List
[
str
])
->
List
[
List
[
float
]]:
req_outputs
=
self
.
model
.
encode
(
prompts
)
outputs
=
[]
for
req_output
in
req_outputs
:
embedding
=
req_output
.
outputs
.
embedding
outputs
.
append
(
embedding
)
return
outputs
def
__del__
(
self
):
def
__del__
(
self
):
del
self
.
model
del
self
.
model
cleanup
()
cleanup
()
...
...
tests/engine/output_processor/test_multi_step.py
View file @
e254497b
...
@@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
...
@@ -9,8 +9,8 @@ from vllm.core.scheduler import Scheduler
from
vllm.engine.output_processor.multi_step
import
MultiStepOutputProcessor
from
vllm.engine.output_processor.multi_step
import
MultiStepOutputProcessor
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
SequenceGroupOutput
,
SequenceOutput
,
from
vllm.sequence
import
(
Completion
SequenceGroupOutput
,
Logprob
,
SequenceStatus
)
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
...
@@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
new_token_ids
=
list
(
range
(
num_new_tokens
))
new_token_ids
=
list
(
range
(
num_new_tokens
))
outputs
=
[
outputs
=
[
SequenceGroupOutput
(
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
SequenceOutput
(
parent_seq_id
=
seq
.
seq_id
,
parent_seq_id
=
seq
.
seq_id
,
...
@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
...
@@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
new_token_ids
=
list
(
range
(
num_new_tokens
))
new_token_ids
=
list
(
range
(
num_new_tokens
))
outputs
=
[
outputs
=
[
SequenceGroupOutput
(
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
SequenceOutput
(
parent_seq_id
=
seq
.
seq_id
,
parent_seq_id
=
seq
.
seq_id
,
...
@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
...
@@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids
[
eos_index
]
=
eos_token_id
new_token_ids
[
eos_index
]
=
eos_token_id
outputs
=
[
outputs
=
[
SequenceGroupOutput
(
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
SequenceOutput
(
parent_seq_id
=
seq
.
seq_id
,
parent_seq_id
=
seq
.
seq_id
,
...
@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
...
@@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
new_token_ids
[
eos_index
]
=
eos_token_id
new_token_ids
[
eos_index
]
=
eos_token_id
outputs
=
[
outputs
=
[
SequenceGroupOutput
(
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
SequenceOutput
(
parent_seq_id
=
seq
.
seq_id
,
parent_seq_id
=
seq
.
seq_id
,
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
e254497b
...
@@ -14,6 +14,7 @@ class MockModelConfig:
...
@@ -14,6 +14,7 @@ class MockModelConfig:
tokenizer_mode
=
"auto"
tokenizer_mode
=
"auto"
max_model_len
=
100
max_model_len
=
100
tokenizer_revision
=
None
tokenizer_revision
=
None
embedding_mode
=
False
@
dataclass
@
dataclass
...
...
tests/entrypoints/test_openai_server.py
View file @
e254497b
...
@@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
...
@@ -23,6 +23,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
MAX_SERVER_START_WAIT_S
=
600
# wait for server to start for 60 seconds
MAX_SERVER_START_WAIT_S
=
600
# wait for server to start for 60 seconds
# any model with a chat template should work here
# any model with a chat template should work here
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
# generation quality here
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
...
@@ -121,7 +122,7 @@ def zephyr_lora_files():
...
@@ -121,7 +122,7 @@ def zephyr_lora_files():
return
snapshot_download
(
repo_id
=
LORA_NAME
)
return
snapshot_download
(
repo_id
=
LORA_NAME
)
@
pytest
.
fixture
(
scope
=
"
session
"
)
@
pytest
.
fixture
(
scope
=
"
module
"
)
def
server
(
zephyr_lora_files
):
def
server
(
zephyr_lora_files
):
ray
.
init
()
ray
.
init
()
server_runner
=
ServerRunner
.
remote
([
server_runner
=
ServerRunner
.
remote
([
...
@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
...
@@ -150,6 +151,25 @@ def server(zephyr_lora_files):
ray
.
shutdown
()
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"module"
)
def
embedding_server
(
zephyr_lora_files
):
ray
.
shutdown
()
ray
.
init
()
server_runner
=
ServerRunner
.
remote
([
"--model"
,
EMBEDDING_MODEL_NAME
,
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"8192"
,
"--enforce-eager"
,
])
ray
.
get
(
server_runner
.
ready
.
remote
())
yield
server_runner
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
client
():
def
client
():
client
=
openai
.
AsyncOpenAI
(
client
=
openai
.
AsyncOpenAI
(
...
@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
...
@@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or
"less_than_equal"
in
exc_info
.
value
.
message
)
or
"less_than_equal"
in
exc_info
.
value
.
message
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
EMBEDDING_MODEL_NAME
],
)
async
def
test_single_embedding
(
embedding_server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
input
=
[
"The chef prepared a delicious meal."
,
]
# test single embedding
embeddings
=
await
client
.
embeddings
.
create
(
model
=
model_name
,
input
=
input
,
encoding_format
=
"float"
,
)
assert
embeddings
.
id
is
not
None
assert
embeddings
.
data
is
not
None
and
len
(
embeddings
.
data
)
==
1
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
9
assert
embeddings
.
usage
.
total_tokens
==
9
# test using token IDs
input
=
[
1
,
1
,
1
,
1
,
1
]
embeddings
=
await
client
.
embeddings
.
create
(
model
=
model_name
,
input
=
input
,
encoding_format
=
"float"
,
)
assert
embeddings
.
id
is
not
None
assert
embeddings
.
data
is
not
None
and
len
(
embeddings
.
data
)
==
1
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
5
assert
embeddings
.
usage
.
total_tokens
==
5
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
EMBEDDING_MODEL_NAME
],
)
async
def
test_batch_embedding
(
embedding_server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
# test List[str]
inputs
=
[
"The cat sat on the mat."
,
"A feline was resting on a rug."
,
"Stars twinkle brightly in the night sky."
]
embeddings
=
await
client
.
embeddings
.
create
(
model
=
model_name
,
input
=
inputs
,
encoding_format
=
"float"
,
)
assert
embeddings
.
id
is
not
None
assert
embeddings
.
data
is
not
None
and
len
(
embeddings
.
data
)
==
3
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
# test List[List[int]]
inputs
=
[[
4
,
5
,
7
,
9
,
20
],
[
15
,
29
,
499
],
[
24
,
24
,
24
,
24
,
24
],
[
25
,
32
,
64
,
77
]]
embeddings
=
await
client
.
embeddings
.
create
(
model
=
model_name
,
input
=
inputs
,
encoding_format
=
"float"
,
)
assert
embeddings
.
id
is
not
None
assert
embeddings
.
data
is
not
None
and
len
(
embeddings
.
data
)
==
4
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
17
assert
embeddings
.
usage
.
total_tokens
==
17
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
tests/models/test_embedding.py
0 → 100644
View file @
e254497b
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_llama_embedding.py`.
"""
import
pytest
import
torch
import
torch.nn.functional
as
F
MODELS
=
[
"intfloat/e5-mistral-7b-instruct"
,
]
def
compare_embeddings
(
embeddings1
,
embeddings2
):
similarities
=
[
F
.
cosine_similarity
(
torch
.
tensor
(
e1
),
torch
.
tensor
(
e2
),
dim
=
0
)
for
e1
,
e2
in
zip
(
embeddings1
,
embeddings2
)
]
return
similarities
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
)
->
None
:
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
del
vllm_model
similarities
=
compare_embeddings
(
hf_outputs
,
vllm_outputs
)
all_similarities
=
torch
.
stack
(
similarities
)
tolerance
=
1e-2
assert
torch
.
all
((
all_similarities
<=
1.0
+
tolerance
)
&
(
all_similarities
>=
1.0
-
tolerance
)
),
f
"Not all values are within
{
tolerance
}
of 1.0"
tests/samplers/test_logits_processor.py
View file @
e254497b
...
@@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
...
@@ -36,14 +36,14 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None
# test logits_processors when prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
model
.
_add_request
(
prompt
=
example_prompts
[
0
],
prompt
=
example_prompts
[
0
],
sampling_
params
=
params_with_logprobs
,
params
=
params_with_logprobs
,
prompt_token_ids
=
None
,
prompt_token_ids
=
None
,
)
)
# test prompt_logprobs is not None
# test prompt_logprobs is not None
vllm_model
.
model
.
_add_request
(
vllm_model
.
model
.
_add_request
(
prompt
=
example_prompts
[
1
],
prompt
=
example_prompts
[
1
],
sampling_
params
=
SamplingParams
(
params
=
SamplingParams
(
prompt_logprobs
=
3
,
prompt_logprobs
=
3
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
),
),
...
@@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
...
@@ -53,7 +53,7 @@ def test_logits_processor_force_generate(
# test grouped requests
# test grouped requests
vllm_model
.
model
.
_add_request
(
vllm_model
.
model
.
_add_request
(
prompt
=
example_prompts
[
2
],
prompt
=
example_prompts
[
2
],
sampling_
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
params
=
SamplingParams
(
max_tokens
=
max_tokens
),
prompt_token_ids
=
None
,
prompt_token_ids
=
None
,
)
)
...
...
tests/samplers/test_seeded_generate.py
View file @
e254497b
...
@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
...
@@ -60,7 +60,7 @@ def test_random_sample_with_seed(
llm
.
_add_request
(
llm
.
_add_request
(
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
None
,
prompt_token_ids
=
None
,
sampling_
params
=
params
,
params
=
params
,
)
)
results
=
llm
.
_run_engine
(
use_tqdm
=
False
)
results
=
llm
.
_run_engine
(
use_tqdm
=
False
)
...
...
tests/spec_decode/utils.py
View file @
e254497b
...
@@ -7,8 +7,8 @@ import torch
...
@@ -7,8 +7,8 @@ import torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
S
equenceGroupMetad
ata
,
SequenceGroup
Output
,
S
amplerOutput
,
SequenceD
ata
,
SequenceGroup
Metadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
...
@@ -170,7 +170,7 @@ def create_sampler_output_list(
...
@@ -170,7 +170,7 @@ def create_sampler_output_list(
return
[
return
[
SamplerOutput
(
outputs
=
[
SamplerOutput
(
outputs
=
[
SequenceGroupOutput
(
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
SequenceOutput
(
output_token
=
token_id
,
output_token
=
token_id
,
...
...
tests/test_sequence.py
View file @
e254497b
import
pytest
import
pytest
from
tests.core.utils
import
create_dummy_prompt
from
tests.core.utils
import
create_dummy_prompt
from
vllm.sequence
import
(
Sa
mple
rOutput
,
SequenceData
,
SequenceGroupOutput
,
from
vllm.sequence
import
(
Co
mple
tion
SequenceGroupOutput
,
SamplerOutput
,
SequenceOutput
)
SequenceData
,
SequenceOutput
)
@
pytest
.
fixture
@
pytest
.
fixture
def
sample_outputs
():
def
sample_outputs
():
return
[
return
[
SequenceGroupOutput
(
samples
=
[
Completion
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
0
,
output_token
=
i
,
logprobs
=
{})
SequenceOutput
(
parent_seq_id
=
0
,
output_token
=
i
,
logprobs
=
{})
],
],
prompt_logprobs
=
None
)
for
i
in
range
(
5
)
prompt_logprobs
=
None
)
for
i
in
range
(
5
)
]
]
...
@@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):
...
@@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs):
def
test_sampler_output_setitem
(
sampler_output
):
def
test_sampler_output_setitem
(
sampler_output
):
new_output
=
SequenceGroupOutput
(
samples
=
[
new_output
=
Completion
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
0
,
output_token
=
99
,
logprobs
=
{})
SequenceOutput
(
parent_seq_id
=
0
,
output_token
=
99
,
logprobs
=
{})
],
],
prompt_logprobs
=
None
)
prompt_logprobs
=
None
)
sampler_output
[
2
]
=
new_output
sampler_output
[
2
]
=
new_output
assert
sampler_output
[
2
]
==
new_output
assert
sampler_output
[
2
]
==
new_output
...
...
vllm/__init__.py
View file @
e254497b
...
@@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
...
@@ -6,7 +6,9 @@ from vllm.engine.llm_engine import LLMEngine
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
__version__
=
"0.4.2"
__version__
=
"0.4.2"
...
@@ -17,9 +19,12 @@ __all__ = [
...
@@ -17,9 +19,12 @@ __all__ = [
"SamplingParams"
,
"SamplingParams"
,
"RequestOutput"
,
"RequestOutput"
,
"CompletionOutput"
,
"CompletionOutput"
,
"EmbeddingOutput"
,
"EmbeddingRequestOutput"
,
"LLMEngine"
,
"LLMEngine"
,
"EngineArgs"
,
"EngineArgs"
,
"AsyncLLMEngine"
,
"AsyncLLMEngine"
,
"AsyncEngineArgs"
,
"AsyncEngineArgs"
,
"initialize_ray_cluster"
,
"initialize_ray_cluster"
,
"PoolingParams"
,
]
]
vllm/config.py
View file @
e254497b
...
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
...
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
get_quantization_config
)
get_quantization_config
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
...
@@ -22,6 +23,7 @@ if TYPE_CHECKING:
...
@@ -22,6 +23,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_GB
=
1
<<
30
_GB
=
1
<<
30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
class
ModelConfig
:
class
ModelConfig
:
...
@@ -126,6 +128,7 @@ class ModelConfig:
...
@@ -126,6 +128,7 @@ class ModelConfig:
served_model_name
)
served_model_name
)
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
self
.
_verify_embedding_mode
()
self
.
_verify_quantization
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
self
.
_verify_cuda_graph
()
...
@@ -137,6 +140,11 @@ class ModelConfig:
...
@@ -137,6 +140,11 @@ class ModelConfig:
"either 'auto' or 'slow'."
)
"either 'auto' or 'slow'."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_embedding_mode
(
self
)
->
None
:
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
self
.
embedding_mode
=
any
(
ModelRegistry
.
is_embedding_model
(
arch
)
for
arch
in
architectures
)
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
...
@@ -591,6 +599,7 @@ class SchedulerConfig:
...
@@ -591,6 +599,7 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -602,6 +611,7 @@ class SchedulerConfig:
...
@@ -602,6 +611,7 @@ class SchedulerConfig:
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
Optional
[
bool
]
=
False
,
)
->
None
:
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
if
max_num_batched_tokens
is
not
None
:
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
max_num_batched_tokens
=
max_num_batched_tokens
...
@@ -610,6 +620,10 @@ class SchedulerConfig:
...
@@ -610,6 +620,10 @@ class SchedulerConfig:
# It is the values that have the best balance between ITL
# It is the values that have the best balance between ITL
# and TTFT on A100. Note it is not optimized for throughput.
# and TTFT on A100. Note it is not optimized for throughput.
self
.
max_num_batched_tokens
=
512
self
.
max_num_batched_tokens
=
512
elif
embedding_mode
:
# For embedding, choose specific value for higher throughput
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
)
else
:
else
:
# If max_model_len is too short, use 2048 as the default value
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
# for higher throughput.
...
@@ -623,6 +637,7 @@ class SchedulerConfig:
...
@@ -623,6 +637,7 @@ class SchedulerConfig:
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
_verify_args
()
self
.
_verify_args
()
...
...
vllm/core/embedding_model_block_manager.py
0 → 100644
View file @
e254497b
from
typing
import
List
,
Tuple
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
class
EmbeddingModelBlockSpaceManager
(
BlockSpaceManager
):
"""An embedding version of BlockSpaceManager for use in environments
with embedding models where block management is not required.
This class provides the same interface as BlockSpaceManager, but its
methods perform no actions or return simple values like True in specific
actions. It's designed to be used in scenarios where the overhead of
block management is unnecessary, such as in an embedding environment.
"""
def
__init__
(
self
,
**
kwargs
,
)
->
None
:
pass
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# Always return OK for dummy purposes
return
AllocStatus
.
OK
def
allocate
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
# No actual allocation logic needed
pass
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
return
True
def
append_slots
(
self
,
seq
:
Sequence
,
num_lookahead_slots
:
int
,
)
->
List
[
Tuple
[
int
,
int
]]:
return
None
# type: ignore
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
pass
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
AllocStatus
:
return
AllocStatus
.
OK
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
List
[
Tuple
[
int
,
int
]]:
return
None
# type: ignore
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
return
True
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
return
None
# type: ignore
def
free
(
self
,
seq
:
Sequence
)
->
None
:
# No operation on free
return
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
return
None
# type: ignore
def
get_num_free_gpu_blocks
(
self
)
->
int
:
return
1
def
get_num_free_cpu_blocks
(
self
)
->
int
:
return
1
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
access_time
:
float
,
)
->
None
:
pass
def
get_common_computed_block_ids
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
return
None
# type: ignore
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
pass
vllm/core/interfaces.py
View file @
e254497b
...
@@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
...
@@ -35,6 +35,11 @@ class BlockSpaceManager(ABC):
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
return
BlockSpaceManagerV2
return
BlockSpaceManagerV2
if
version
==
"embedding"
:
from
vllm.core.embedding_model_block_manager
import
(
EmbeddingModelBlockSpaceManager
)
return
EmbeddingModelBlockSpaceManager
raise
ValueError
(
f
"Unknown version
{
version
=
}
"
)
raise
ValueError
(
f
"Unknown version
{
version
=
}
"
)
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
e254497b
...
@@ -270,9 +270,14 @@ class Scheduler:
...
@@ -270,9 +270,14 @@ class Scheduler:
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_model_len
,
self
.
scheduler_config
.
max_num_batched_tokens
)
self
.
scheduler_config
.
max_num_batched_tokens
)
version
=
"v1"
if
self
.
scheduler_config
.
use_v2_block_manager
:
version
=
"v2"
if
self
.
scheduler_config
.
embedding_mode
:
version
=
"embedding"
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
BlockSpaceManagerImpl
=
BlockSpaceManager
.
get_block_space_manager_class
(
version
=
"v2"
if
self
.
scheduler_config
.
version
)
use_v2_block_manager
else
"v1"
)
# Create the block space manager.
# Create the block space manager.
self
.
block_manager
=
BlockSpaceManagerImpl
(
self
.
block_manager
=
BlockSpaceManagerImpl
(
...
@@ -968,6 +973,7 @@ class Scheduler:
...
@@ -968,6 +973,7 @@ class Scheduler:
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
do_sample
=
do_sample
,
do_sample
=
do_sample
,
pooling_params
=
seq_group
.
pooling_params
,
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
...
...
vllm/engine/arg_utils.py
View file @
e254497b
...
@@ -574,6 +574,7 @@ class EngineArgs:
...
@@ -574,6 +574,7 @@ class EngineArgs:
speculative_config
.
num_lookahead_slots
),
speculative_config
.
num_lookahead_slots
),
delay_factor
=
self
.
scheduler_delay_factor
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
)
)
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
...
...
vllm/engine/async_llm_engine.py
View file @
e254497b
...
@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
...
@@ -14,7 +14,8 @@ from vllm.engine.llm_engine import LLMEngine
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -47,15 +48,16 @@ def _raise_exception_on_finish(
...
@@ -47,15 +48,16 @@ def _raise_exception_on_finish(
class
AsyncStream
:
class
AsyncStream
:
"""A stream of RequestOutputs
f
or
a request that can be
"""A stream of RequestOutputs or
EmbeddingRequestOutputs for a request
iterated over asynchronously."""
that can be
iterated over asynchronously."""
def
__init__
(
self
,
request_id
:
str
)
->
None
:
def
__init__
(
self
,
request_id
:
str
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
Exception
])
->
None
:
def
put
(
self
,
item
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
,
Exception
])
->
None
:
if
self
.
_finished
:
if
self
.
_finished
:
return
return
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
...
@@ -71,7 +73,7 @@ class AsyncStream:
...
@@ -71,7 +73,7 @@ class AsyncStream:
def
__aiter__
(
self
):
def
__aiter__
(
self
):
return
self
return
self
async
def
__anext__
(
self
)
->
RequestOutput
:
async
def
__anext__
(
self
)
->
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
:
result
=
await
self
.
_queue
.
get
()
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
isinstance
(
result
,
Exception
):
raise
result
raise
result
...
@@ -108,7 +110,8 @@ class RequestTracker:
...
@@ -108,7 +110,8 @@ class RequestTracker:
self
.
abort_request
(
rid
)
self
.
abort_request
(
rid
)
def
process_request_output
(
self
,
def
process_request_output
(
self
,
request_output
:
RequestOutput
,
request_output
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
*
,
*
,
verbose
:
bool
=
False
)
->
None
:
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
"""Process a request output from the engine."""
...
@@ -196,7 +199,8 @@ class RequestTracker:
...
@@ -196,7 +199,8 @@ class RequestTracker:
class
_AsyncLLMEngine
(
LLMEngine
):
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
"""Extension of LLMEngine to add async methods."""
async
def
step_async
(
self
)
->
List
[
RequestOutput
]:
async
def
step_async
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
The workers are ran asynchronously if possible.
...
@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -251,7 +255,7 @@ class _AsyncLLMEngine(LLMEngine):
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
s
ampling
_p
arams
:
Samp
lingParams
,
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -270,8 +274,8 @@ class _AsyncLLMEngine(LLMEngine):
return
self
.
add_request
(
request_id
,
return
self
.
add_request
(
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
params
=
params
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
)
...
@@ -511,7 +515,7 @@ class AsyncLLMEngine:
...
@@ -511,7 +515,7 @@ class AsyncLLMEngine:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
s
ampling
_p
arams
:
Samp
lingParams
,
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -528,9 +532,9 @@ class AsyncLLMEngine:
...
@@ -528,9 +532,9 @@ class AsyncLLMEngine:
max_log_len
]
max_log_len
]
logger
.
info
(
logger
.
info
(
"Received request %s: prompt: %r, "
"Received request %s: prompt: %r, "
"
sampling_
params: %s, prompt_token_ids: %s, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s."
,
request_id
,
shortened_prompt
,
"lora_request: %s."
,
request_id
,
shortened_prompt
,
params
,
sampling_params
,
shortened_token_ids
,
lora_request
)
shortened_token_ids
,
lora_request
)
if
not
self
.
is_running
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
if
self
.
start_engine_loop
:
...
@@ -562,7 +566,7 @@ class AsyncLLMEngine:
...
@@ -562,7 +566,7 @@ class AsyncLLMEngine:
stream
=
self
.
_request_tracker
.
add_request
(
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
sampling_params
=
sampling_
params
,
params
=
params
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
@@ -597,8 +601,8 @@ class AsyncLLMEngine:
...
@@ -597,8 +601,8 @@ class AsyncLLMEngine:
multi_modal_data: Multi modal data per request.
multi_modal_data: Multi modal data per request.
Yields:
Yields:
The output `RequestOutput` objects from the LLMEngine
for the
The output `RequestOutput` objects from the LLMEngine
request.
for the
request.
Details:
Details:
- If the engine is not running, start the background loop,
- If the engine is not running, start the background loop,
...
@@ -643,25 +647,123 @@ class AsyncLLMEngine:
...
@@ -643,25 +647,123 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> # Process and return the final output
>>> ...
>>> ...
"""
"""
# Preprocess the request.
async
for
output
in
self
.
process_request
(
arrival_time
=
time
.
time
()
try
:
stream
=
await
self
.
add_request
(
request_id
,
request_id
,
prompt
,
prompt
,
sampling_params
,
sampling_params
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
,
arrival_time
=
arrival_time
,
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
,
multi_modal_data
=
multi_modal_data
,
):
)
yield
output
async
def
encode
(
self
,
prompt
:
Optional
[
str
],
pooling_params
:
PoolingParams
,
request_id
:
str
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
)
->
AsyncIterator
[
EmbeddingRequestOutput
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.encode(
>>> example_input["input"],
>>> PoolingParams(),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
"""
async
for
output
in
self
.
process_request
(
request_id
,
prompt
,
pooling_params
,
prompt_token_ids
,
lora_request
,
multi_modal_data
,
):
yield
output
async
def
process_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
str
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
AsyncIterator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time
=
time
.
time
()
stream
=
await
self
.
add_request
(
request_id
,
prompt
,
params
,
prompt_token_ids
=
prompt_token_ids
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
)
try
:
async
for
request_output
in
stream
:
async
for
request_output
in
stream
:
yield
request_output
yield
request_output
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
# If there is an exception or coroutine is cancelled, abort the
# request.
self
.
_abort
(
request_id
)
self
.
_abort
(
request_id
)
raise
e
raise
e
...
...
vllm/engine/llm_engine.py
View file @
e254497b
...
@@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
...
@@ -20,9 +20,12 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
MultiModalData
,
PoolerOutput
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
...
@@ -169,7 +172,8 @@ class LLMEngine:
...
@@ -169,7 +172,8 @@ class LLMEngine:
load_config
=
load_config
,
load_config
=
load_config
,
)
)
self
.
_initialize_kv_caches
()
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
# If usage stat is enabled, collect relevant info.
if
is_usage_stats_enabled
():
if
is_usage_stats_enabled
():
...
@@ -354,7 +358,7 @@ class LLMEngine:
...
@@ -354,7 +358,7 @@ class LLMEngine:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
s
ampling
_p
arams
:
Samp
lingParams
,
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
@@ -370,7 +374,8 @@ class LLMEngine:
...
@@ -370,7 +374,8 @@ class LLMEngine:
request_id: The unique ID of the request.
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
provided.
sampling_params: The sampling parameters for text generation.
params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
arrival_time: The arrival time of the request. If None, we use
...
@@ -404,13 +409,6 @@ class LLMEngine:
...
@@ -404,13 +409,6 @@ class LLMEngine:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
"not enabled!"
)
max_logprobs
=
self
.
get_model_config
().
max_logprobs
if
(
sampling_params
.
logprobs
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
sampling_params
.
prompt_logprobs
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
prompt_token_ids
=
self
.
encode_request
(
prompt_token_ids
=
self
.
encode_request
(
...
@@ -432,6 +430,50 @@ class LLMEngine:
...
@@ -432,6 +430,50 @@ class LLMEngine:
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
seq
=
Sequence
(
seq_id
,
prompt
,
prompt_token_ids
,
block_size
,
eos_token_id
,
lora_request
)
eos_token_id
,
lora_request
)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if
isinstance
(
params
,
SamplingParams
):
seq_group
=
self
.
_create_sequence_group_with_sampling
(
request_id
,
seq
,
params
,
arrival_time
,
lora_request
,
multi_modal_data
,
)
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
seq
,
params
,
arrival_time
,
lora_request
,
multi_modal_data
,
)
else
:
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
# Add the sequence group to the scheduler.
self
.
scheduler
.
add_seq_group
(
seq_group
)
def
_create_sequence_group_with_sampling
(
self
,
request_id
:
str
,
seq
:
Sequence
,
sampling_params
:
SamplingParams
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
if
(
sampling_params
.
logprobs
and
sampling_params
.
logprobs
>
max_logprobs
)
or
(
sampling_params
.
prompt_logprobs
and
sampling_params
.
prompt_logprobs
>
max_logprobs
):
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
# Defensive copy of SamplingParams, which are used by the sampler,
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
# this doesn't deep-copy LogitsProcessor objects
sampling_params
=
sampling_params
.
clone
()
sampling_params
=
sampling_params
.
clone
()
...
@@ -443,11 +485,35 @@ class LLMEngine:
...
@@ -443,11 +485,35 @@ class LLMEngine:
self
.
generation_config_fields
)
self
.
generation_config_fields
)
# Create the sequence group.
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
,
[
seq
],
sampling_params
,
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
arrival_time
,
lora_request
,
multi_modal_data
)
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
sampling_params
=
sampling_params
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
# Add the sequence group to the scheduler.
return
seq_group
self
.
scheduler
.
add_seq_group
(
seq_group
)
def
_create_sequence_group_with_pooling
(
self
,
request_id
:
str
,
seq
:
Sequence
,
pooling_params
:
PoolingParams
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
pooling_params
=
pooling_params
.
clone
()
# Create the sequence group.
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
seq
],
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
,
pooling_params
=
pooling_params
)
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
"""Aborts a request(s) with the given ID.
"""Aborts a request(s) with the given ID.
...
@@ -484,13 +550,25 @@ class LLMEngine:
...
@@ -484,13 +550,25 @@ class LLMEngine:
"""Returns True if there are unfinished requests."""
"""Returns True if there are unfinished requests."""
return
self
.
scheduler
.
has_unfinished_seqs
()
return
self
.
scheduler
.
has_unfinished_seqs
()
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
)
->
None
:
seq_group
.
embeddings
=
outputs
[
0
].
embeddings
for
seq
in
seq_group
.
get_seqs
():
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
def
_process_model_outputs
(
def
_process_model_outputs
(
self
,
self
,
output
:
List
[
SamplerOutput
],
output
:
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
RequestOutput
]:
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
...
@@ -510,6 +588,9 @@ class LLMEngine:
...
@@ -510,6 +588,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
continue
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
if
seq_group_meta
.
do_sample
:
if
seq_group_meta
.
do_sample
:
...
@@ -519,18 +600,19 @@ class LLMEngine:
...
@@ -519,18 +600,19 @@ class LLMEngine:
self
.
scheduler
.
free_finished_seq_groups
()
self
.
scheduler
.
free_finished_seq_groups
()
# Create the outputs.
# Create the outputs.
request_outputs
:
List
[
RequestOutput
]
=
[]
request_outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
for
scheduled_seq_group
in
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
Factory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
for
seq_group
in
ignored_seq_groups
:
for
seq_group
in
ignored_seq_groups
:
request_output
=
RequestOutput
.
from_seq_group
(
seq_group
)
request_output
=
RequestOutput
Factory
.
create
(
seq_group
)
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
return
request_outputs
return
request_outputs
def
step
(
self
)
->
List
[
RequestOutput
]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
.. figure:: https://i.imgur.com/sv2HssD.png
.. figure:: https://i.imgur.com/sv2HssD.png
...
@@ -570,7 +652,7 @@ class LLMEngine:
...
@@ -570,7 +652,7 @@ class LLMEngine:
>>> while True:
>>> while True:
>>> if example_inputs:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id),
prompt,
sampling_params)
>>> engine.add_request(str(req_id),prompt,sampling_params)
>>>
>>>
>>> # continue the request processing
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> request_outputs = engine.step()
...
@@ -637,12 +719,15 @@ class LLMEngine:
...
@@ -637,12 +719,15 @@ class LLMEngine:
# KV Cache Usage in %
# KV Cache Usage in %
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
()
gpu_cache_usage_sys
=
0.
gpu_cache_usage_sys
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
if
num_total_gpu
is
not
None
:
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
(
)
gpu_cache_usage_sys
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
cpu_cache_usage_sys
=
0.
cpu_cache_usage_sys
=
0.
if
num_total_cpu
>
0
:
if
num_total_cpu
is
not
None
and
num_total_cpu
>
0
:
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
)
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
...
@@ -716,8 +801,10 @@ class LLMEngine:
...
@@ -716,8 +801,10 @@ class LLMEngine:
seq
.
get_output_len
()
seq
.
get_output_len
()
for
seq
in
seq_group
.
get_finished_seqs
()
for
seq
in
seq_group
.
get_finished_seqs
()
])
])
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
if
seq_group
.
sampling_params
is
not
None
:
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
finished_reason_requests
.
extend
([
finished_reason_requests
.
extend
([
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
for
seq
in
seq_group
.
get_finished_seqs
()
for
seq
in
seq_group
.
get_finished_seqs
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment