Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e254497b
Unverified
Commit
e254497b
authored
May 11, 2024
by
Chang Su
Committed by
GitHub
May 11, 2024
Browse files
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
parent
4e121310
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1028 additions
and
70 deletions
+1028
-70
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+121
-29
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+18
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+35
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+134
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+14
-2
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+5
-5
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+56
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-3
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+11
-1
vllm/model_executor/models/llama_embedding.py
vllm/model_executor/models/llama_embedding.py
+87
-0
vllm/model_executor/pooling_metadata.py
vllm/model_executor/pooling_metadata.py
+69
-0
vllm/outputs.py
vllm/outputs.py
+81
-1
vllm/pooling_params.py
vllm/pooling_params.py
+20
-0
vllm/sequence.py
vllm/sequence.py
+77
-12
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+3
-2
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+266
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-7
vllm/worker/worker.py
vllm/worker/worker.py
+9
-5
No files found.
vllm/entrypoints/llm.py
View file @
e254497b
...
...
@@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -164,8 +168,89 @@ class LLM:
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
requests_data
=
self
.
_validate_and_prepare_requests
(
prompts
,
sampling_params
,
prompt_token_ids
,
lora_request
,
multi_modal_data
,
)
# Add requests to the engine and run the engine
for
request_data
in
requests_data
:
self
.
_add_request
(
**
request_data
)
return
self
.
_run_engine
(
use_tqdm
)
def
encode
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
List
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
"""
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
requests_data
=
self
.
_validate_and_prepare_requests
(
prompts
,
pooling_params
,
prompt_token_ids
,
lora_request
,
multi_modal_data
,
)
# Add requests to the engine and run the engine
for
request_data
in
requests_data
:
self
.
_add_request
(
**
request_data
)
return
self
.
_run_engine
(
use_tqdm
)
def
_validate_and_prepare_requests
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]],
params
:
Union
[
Union
[
SamplingParams
,
PoolingParams
],
List
[
Union
[
SamplingParams
,
PoolingParams
]]],
# Unified parameter
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
dict
]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of
dictionaries with request data for further processing.
"""
if
prompts
is
None
and
prompt_token_ids
is
None
:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
...
...
@@ -188,40 +273,43 @@ class LLM:
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and sampling_params "
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
requests_data
=
[]
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
self
.
_add_request
(
multi_modal_item
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
),
)
if
multi_modal_data
else
None
requests_data
.
append
({
"prompt"
:
prompt
,
sampling_params
[
i
]
if
isinstance
(
sampling_params
,
list
)
else
sampling_params
,
"params"
:
params
[
i
]
if
isinstance
(
params
,
list
)
else
params
,
"prompt_token_ids"
:
token_ids
,
lora_request
=
lora_request
,
# Get ith image while maintaining the batch dim.
multi_modal_data
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
))
if
multi_modal_data
else
None
,
)
return
self
.
_run_engine
(
use_tqdm
)
"lora_request"
:
lora_request
,
"multi_modal_data"
:
multi_modal_item
,
})
return
requests_data
def
_add_request
(
self
,
prompt
:
Optional
[
str
],
s
ampling
_p
arams
:
Samp
lingParams
,
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
,
prompt_token_ids
:
Optional
[
List
[
int
]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
...
...
@@ -229,12 +317,14 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
prompt
,
sampling_
params
,
params
,
prompt_token_ids
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
# Initialize tqdm.
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
...
@@ -245,7 +335,7 @@ class LLM:
postfix
=
f
"Generation Speed:
{
0
:.
2
f
}
toks/s"
,
)
# Run the engine.
outputs
:
List
[
RequestOutput
]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]
=
[]
total_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
...
...
@@ -253,10 +343,12 @@ class LLM:
if
output
.
finished
:
outputs
.
append
(
output
)
if
use_tqdm
:
total_toks
+=
(
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
))
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
total_toks
+=
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
)
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
pbar
.
update
(
1
)
if
use_tqdm
:
pbar
.
close
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
e254497b
...
...
@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
CompletionRequest
,
ErrorResponse
)
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
...
...
@@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
logger
=
init_logger
(
__name__
)
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
...
...
@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
@
app
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
return
JSONResponse
(
content
=
generator
.
model_dump
())
if
__name__
==
"__main__"
:
args
=
parse_args
()
...
...
@@ -190,7 +205,8 @@ if __name__ == "__main__":
args
.
chat_template
)
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
model_config
,
served_model_names
,
args
.
lora_modules
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
)
app
.
root_path
=
args
.
root_path
uvicorn
.
run
(
app
,
host
=
args
.
host
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
e254497b
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
from
openai.types.chat
import
ChatCompletionMessageParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
...
...
@@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
return
data
class
EmbeddingRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model
:
str
input
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
encoding_format
:
Optional
[
str
]
=
Field
(
'float'
,
pattern
=
'^(float|base64)$'
)
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
# doc: begin-embedding-pooling-params
additional_data
:
Optional
[
Any
]
=
None
# doc: end-embedding-pooling-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
LogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
...
...
@@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
EmbeddingResponseData
(
BaseModel
):
index
:
int
object
:
str
=
"embedding"
embedding
:
List
[
float
]
class
EmbeddingResponse
(
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"list"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
data
:
List
[
EmbeddingResponseData
]
usage
:
UsageInfo
class
ChatMessage
(
OpenAIBaseModel
):
role
:
str
content
:
str
...
...
vllm/entrypoints/openai/serving_embedding.py
0 → 100644
View file @
e254497b
import
time
from
typing
import
AsyncIterator
,
List
,
Tuple
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_completion
import
parse_prompt_format
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
TypeTokenIDs
=
List
[
int
]
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
EmbeddingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
EmbeddingResponse
:
data
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
final_res
.
outputs
.
embedding
)
data
.
append
(
embedding_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
data
,
usage
=
usage
,
)
class
OpenAIServingEmbedding
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
)
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
raw_request
:
Request
):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
# Return error for unsupported features.
if
request
.
encoding_format
==
"base64"
:
return
self
.
create_error_response
(
"base64 encoding is not currently supported"
)
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
request
.
model
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
# Schedule the request and get the result generator.
generators
=
[]
try
:
prompt_is_tokens
,
prompts
=
parse_prompt_format
(
request
.
input
)
pooling_params
=
request
.
to_pooling_params
()
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
)
else
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
prompt_ids
,
prompt_text
=
prompt_formats
generators
.
append
(
self
.
engine
.
generate
(
prompt_text
,
pooling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
prompt_token_ids
=
prompt_ids
))
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
int
,
EmbeddingRequestOutput
]]
=
merge_async_iterators
(
*
generators
)
# Non-streaming response
final_res_batch
:
EmbeddingRequestOutput
=
[
None
]
*
len
(
prompts
)
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
request_output_to_embedding_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
)
return
response
def
_check_embedding_mode
(
self
,
embedding_mode
:
bool
):
if
not
embedding_mode
:
logger
.
warning
(
"embedding_mode is False. Embedding API will not work."
)
else
:
logger
.
info
(
"Activating the server engine with embedding enabled."
)
vllm/entrypoints/openai/serving_engine.py
View file @
e254497b
...
...
@@ -9,7 +9,8 @@ from typing_extensions import Annotated
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
,
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
,
LogProbs
,
ModelCard
,
ModelList
,
ModelPermission
)
from
vllm.logger
import
init_logger
...
...
@@ -165,7 +166,8 @@ class OpenAIServing:
def
_validate_prompt_and_tokenize
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
],
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
...
...
@@ -191,6 +193,16 @@ class OpenAIServing:
prompt_ids
)
token_num
=
len
(
input_ids
)
# Note: EmbeddingRequest doesn't have max_tokens
if
isinstance
(
request
,
EmbeddingRequest
):
if
token_num
>
self
.
max_model_len
:
raise
ValueError
(
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the input for embedding "
f
"generation. Please reduce the length of the input."
,
)
return
input_ids
,
input_text
if
request
.
max_tokens
is
None
:
if
token_num
>=
self
.
max_model_len
:
raise
ValueError
(
...
...
vllm/executor/gpu_executor.py
View file @
e254497b
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
...
...
@@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
self
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
return
output
vllm/model_executor/layers/pooler.py
0 → 100644
View file @
e254497b
from
enum
import
IntEnum
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingTensors
)
from
vllm.sequence
import
EmbeddingSequenceGroupOutput
,
PoolerOutput
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
LAST
=
0
class
Pooler
(
nn
.
Module
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
normalize: Whether to normalize the pooled data.
"""
def
__init__
(
self
,
pooling_type
:
PoolingType
,
normalize
:
bool
):
super
().
__init__
()
self
.
pooling_type
=
pooling_type
self
.
normalize
=
normalize
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""Pools specific information from hidden states based on metadata."""
prompt_lens
=
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
if
self
.
pooling_type
==
PoolingType
.
LAST
:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
pooled_data
=
hidden_states
[
last_token_flat_indices
]
else
:
raise
ValueError
(
f
"Invalid pooling type:
{
self
.
pooling_type
}
"
)
if
self
.
normalize
:
pooled_data
=
nn
.
functional
.
normalize
(
pooled_data
,
p
=
2
,
dim
=
1
)
pooled_outputs
=
[
EmbeddingSequenceGroupOutput
(
data
.
tolist
())
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
vllm/model_executor/layers/sampler.py
View file @
e254497b
...
...
@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors
,
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
...
@@ -1019,7 +1020,7 @@ def _build_sampler_output(
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
Completion
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
...
...
vllm/model_executor/models/__init__.py
View file @
e254497b
...
...
@@ -9,7 +9,7 @@ from vllm.utils import is_hip
logger
=
init_logger
(
__name__
)
# Architecture -> (module, class).
_MODELS
=
{
_GENERATION
_MODELS
=
{
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
...
...
@@ -58,6 +58,12 @@ _MODELS = {
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
}
_EMBEDDING_MODELS
=
{
"MistralModel"
:
(
"llama_embedding"
,
"LlamaEmbeddingModel"
),
}
_MODELS
=
{
**
_GENERATION_MODELS
,
**
_EMBEDDING_MODELS
}
# Architecture -> type.
# out of tree models
_OOT_MODELS
:
Dict
[
str
,
Type
[
nn
.
Module
]]
=
{}
...
...
@@ -114,6 +120,10 @@ class ModelRegistry:
global
_OOT_MODELS
_OOT_MODELS
[
model_arch
]
=
model_cls
@
staticmethod
def
is_embedding_model
(
model_arch
:
str
)
->
bool
:
return
model_arch
in
_EMBEDDING_MODELS
__all__
=
[
"ModelRegistry"
,
...
...
vllm/model_executor/models/llama_embedding.py
0 → 100644
View file @
e254497b
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
PoolerOutput
class
LlamaEmbeddingModel
(
nn
.
Module
):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
**
kwargs
,
)
->
None
:
super
().
__init__
()
self
.
model
=
LlamaModel
(
**
kwargs
)
self
.
_pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
model
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/pooling_metadata.py
0 → 100644
View file @
e254497b
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Tuple
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils
import
is_pin_memory_available
class
PoolingMetadata
:
"""Metadata for pooling operations in the Pooler layer.
This class holds the necessary information for pooling operations,
providing context for how to perform pooling and other related operations.
Attributes:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
"""
def
__init__
(
self
,
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]],
seq_data
:
Dict
[
int
,
Any
],
# Specific data related to sequences
prompt_lens
:
List
[
int
],
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_data
=
seq_data
self
.
prompt_lens
=
prompt_lens
def
__repr__
(
self
)
->
str
:
return
(
"PoolingMetadata("
f
"seq_groups=
{
self
.
seq_groups
}
, "
f
"seq_data=
{
self
.
seq_data
}
, "
f
"prompt_lens=
{
self
.
prompt_lens
}
)"
)
@
dataclass
class
PoolingTensors
:
"""Tensors for pooling."""
prompt_lens
:
torch
.
Tensor
@
classmethod
def
from_pooling_metadata
(
cls
,
pooling_metadata
:
"PoolingMetadata"
,
device
:
torch
.
device
,
)
->
"PoolingTensors"
:
"""
Create PoolingTensors from PoolingMetadata.
Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory
=
is_pin_memory_available
()
prompt_lens_t
=
torch
.
tensor
(
pooling_metadata
.
prompt_lens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
return
cls
(
prompt_lens
=
prompt_lens_t
.
to
(
device
=
device
,
non_blocking
=
True
),
)
vllm/outputs.py
View file @
e254497b
...
...
@@ -57,8 +57,27 @@ class CompletionOutput:
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
class
EmbeddingOutput
:
"""The output data of one completion output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide.
"""
def
__init__
(
self
,
embedding
:
List
[
float
],
)
->
None
:
self
.
embedding
=
embedding
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingOutput("
f
"embedding=
{
len
(
self
.
embedding
)
}
"
)
class
RequestOutput
:
"""The output data of a request to the LLM.
"""The output data of a
completion
request to the LLM.
Args:
request_id: The unique ID of the request.
...
...
@@ -93,6 +112,9 @@ class RequestOutput:
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
if
seq_group
.
sampling_params
is
None
:
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
...
...
@@ -148,3 +170,61 @@ class RequestOutput:
f
"finished=
{
self
.
finished
}
, "
f
"metrics=
{
self
.
metrics
}
, "
f
"lora_request=
{
self
.
lora_request
}
)"
)
class
EmbeddingRequestOutput
:
"""
The output data of an embedding request to the LLM.
Args:
request_id (str): A unique identifier for the embedding request.
outputs (EmbeddingOutput): The embedding results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the embedding is completed.
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
'EmbeddingOutput'
,
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
finished
=
finished
self
.
outputs
=
outputs
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
'SequenceGroup'
)
->
"EmbeddingRequestOutput"
:
if
seq_group
.
embeddings
is
None
:
raise
ValueError
(
"Embeddings are missing in seq_group for EmbeddingRequest."
)
output
=
EmbeddingOutput
(
seq_group
.
embeddings
)
prompt_token_ids
=
seq_group
.
prompt_token_ids
finished
=
seq_group
.
is_finished
()
return
cls
(
seq_group
.
request_id
,
output
,
prompt_token_ids
,
finished
)
def
__repr__
(
self
):
"""
Returns a string representation of an EmbeddingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results.
Returns:
str: A string representation of the EmbeddingRequestOutput instance.
"""
return
(
f
"EmbeddingRequestOutput(request_id='
{
self
.
request_id
}
', "
f
"outputs=
{
repr
(
self
.
outputs
)
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"finished=
{
self
.
finished
}
)"
)
class
RequestOutputFactory
:
@
staticmethod
def
create
(
seq_group
):
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
)
vllm/pooling_params.py
0 → 100644
View file @
e254497b
from
typing
import
Any
,
Optional
class
PoolingParams
:
"""Pooling parameters for pooling.
Attributes:
additional_data: Any additional data needed for pooling.
"""
def
__init__
(
self
,
additional_data
:
Optional
[
Any
]
=
None
):
self
.
additional_data
=
additional_data
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
additional_data
=
self
.
additional_data
,
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
vllm/sequence.py
View file @
e254497b
"""Sequence and its related classes."""
import
copy
import
enum
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
vllm.block
import
LogicalTokenBlock
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
if
TYPE_CHECKING
:
...
...
@@ -375,12 +377,12 @@ class SequenceGroupState:
class
MultiModalData
:
"""Multi modal request.
Args:
type: The data type.
data: The actual data.
The required shape and semantic meaning of it depends on the vision
language config of the hosted model.
language config of the hosted model.
See `VisionLanguageConfig` in `config.py`.
"""
...
...
@@ -402,16 +404,22 @@ class SequenceGroup:
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
for an embedding model.
"""
def
__init__
(
self
,
request_id
:
str
,
seqs
:
List
[
Sequence
],
sampling_params
:
SamplingParams
,
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
...
...
@@ -425,6 +433,8 @@ class SequenceGroup:
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
multi_modal_data
=
multi_modal_data
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
@
property
def
prompt
(
self
)
->
str
:
...
...
@@ -479,12 +489,13 @@ class SequenceGroup:
def
get_max_num_running_seqs
(
self
)
->
int
:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if
self
.
sampling_params
.
use_beam_search
:
if
self
.
sampling_params
and
self
.
sampling_params
.
use_beam_search
:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return
self
.
sampling_params
.
best_of
else
:
if
self
.
sampling_params
.
best_of
>
self
.
num_seqs
():
if
(
self
.
sampling_params
and
self
.
sampling_params
.
best_of
>
self
.
num_seqs
()):
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
...
...
@@ -555,7 +566,7 @@ class SequenceGroup:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
def
is_prefill
(
self
)
->
bool
:
# Every sequence
s
should be in the same stage.
# Every sequence should be in the same stage.
return
self
.
get_seqs
()[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
...
...
@@ -594,6 +605,7 @@ class SequenceGroupMetadata:
sampling_params
:
SamplingParams
,
block_tables
:
Dict
[
int
,
List
[
int
]],
do_sample
:
bool
=
True
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -605,6 +617,7 @@ class SequenceGroupMetadata:
self
.
seq_data
=
seq_data
self
.
sampling_params
=
sampling_params
self
.
block_tables
=
block_tables
self
.
pooling_params
=
pooling_params
self
.
lora_request
=
lora_request
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
...
...
@@ -669,8 +682,20 @@ class SequenceOutput:
return
equal
and
log_probs_equal
class
SequenceGroupOutput
:
"""The model output associated with a sequence group."""
class
SequenceGroupOutput
(
ABC
):
"""The base class for model outputs associated with a sequence group."""
@
abstractmethod
def
__repr__
(
self
)
->
str
:
pass
@
abstractmethod
def
__eq__
(
self
,
other
:
object
)
->
bool
:
pass
class
CompletionSequenceGroupOutput
(
SequenceGroupOutput
):
"""The model output associated with a completion sequence group."""
def
__init__
(
self
,
...
...
@@ -682,26 +707,45 @@ class SequenceGroupOutput:
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroupOutput(samples=
{
self
.
samples
}
, "
return
(
f
"
Completion
SequenceGroupOutput(samples=
{
self
.
samples
}
, "
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
SequenceGroupOutput
):
if
not
isinstance
(
other
,
Completion
SequenceGroupOutput
):
raise
NotImplementedError
()
return
(
self
.
samples
==
other
.
samples
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
class
EmbeddingSequenceGroupOutput
(
SequenceGroupOutput
):
"""The model output associated with an embedding sequence group."""
def
__init__
(
self
,
embeddings
:
List
[
float
],
)
->
None
:
self
.
embeddings
=
embeddings
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingSequenceGroupOutput("
f
"embeddings_shape=
{
len
(
self
.
embeddings
)
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
EmbeddingSequenceGroupOutput
):
raise
NotImplementedError
()
return
self
.
embeddings
==
other
.
embeddings
@
dataclass
class
SamplerOutput
:
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This datastructure implements methods so it can be used like a list, but
This data
structure implements methods
,
so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs
:
List
[
SequenceGroupOutput
]
outputs
:
List
[
Completion
SequenceGroupOutput
]
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
...
...
@@ -742,6 +786,27 @@ class SamplerOutput:
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@
dataclass
class
PoolerOutput
:
"""The output from a pooling operation in the embedding model."""
outputs
:
List
[
EmbeddingSequenceGroupOutput
]
spec_decode_worker_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
return
len
(
self
.
outputs
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
@
dataclass
class
ExecuteModelRequest
:
"""The model execution request."""
...
...
vllm/spec_decode/util.py
View file @
e254497b
...
...
@@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
import
torch
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SeqId
=
int
...
...
@@ -94,7 +95,7 @@ def create_sequence_group_output(
for
topk_logprob_index
,
_
in
enumerate
(
topk_token_ids
)
})
return
SequenceGroupOutput
(
return
Completion
SequenceGroupOutput
(
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
...
...
vllm/worker/embedding_model_runner.py
0 → 100644
View file @
e254497b
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
BatchType
,
ModelRunner
logger
=
init_logger
(
__name__
)
class
EmbeddingModelRunner
(
ModelRunner
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
):
super
().
__init__
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
vision_language_config
=
vision_language_config
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
PoolerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
if
self
.
lora_config
:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
# Currently cuda graph is only supported by the decode phase.
prefill_meta
=
attn_metadata
.
prefill_metadata
decode_meta
=
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
graph_batch_size
=
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
model_executable
=
self
.
model
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
return
self
.
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
)
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
PoolingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
prefill_reqs
=
[]
decode_reqs
=
[]
for
seq_group_meta
in
seq_group_metadata_list
:
if
seq_group_meta
.
is_prompt
:
prefill_reqs
.
append
(
seq_group_meta
)
else
:
decode_reqs
.
append
(
seq_group_meta
)
# Prepare input tensors.
(
input_tokens
,
input_positions
,
prefill_attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
,
multi_modal_input
,
slot_mapping
,
)
=
self
.
_prepare_prompt
(
prefill_reqs
)
(
decode_input_tokens
,
decode_input_positions
,
decode_attn_metadata
,
decode_lora_index_mapping
,
decode_lora_prompt_mapping
,
decode_lora_requests
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
# Prepare PoolingMetadata
pooling_metadata
=
self
.
_prepare_pooling
(
seq_group_metadata_list
,
prompt_lens
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens
.
extend
(
decode_input_tokens
)
input_positions
.
extend
(
decode_input_positions
)
slot_mapping
.
extend
(
decode_slot_mapping
)
lora_index_mapping
.
extend
(
decode_lora_index_mapping
)
lora_prompt_mapping
.
extend
(
decode_lora_prompt_mapping
)
lora_requests
.
update
(
decode_lora_requests
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if
(
prefill_attn_metadata
is
not
None
and
decode_attn_metadata
is
not
None
):
batch_type
=
BatchType
.
MIXED
elif
prefill_attn_metadata
is
not
None
:
batch_type
=
BatchType
.
PREFILL
else
:
batch_type
=
BatchType
.
DECODE
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_input"
:
multi_modal_input
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"num_prefills"
:
num_prefills
,
"batch_type"
:
batch_type
,
}
if
prefill_attn_metadata
is
not
None
:
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
else
:
assert
decode_attn_metadata
is
not
None
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if
batch_type
==
BatchType
.
MIXED
:
assert
decode_attn_metadata
is
not
None
metadata_dict
=
decode_attn_metadata
.
asdict_zerocopy
()
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
slot_mapping
=
metadata_dict
.
pop
(
"slot_mapping"
)
num_prefills
=
metadata_dict
.
pop
(
"num_prefills"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
num_prefill_tokens
=
metadata_dict
.
pop
(
"num_prefill_tokens"
)
num_decode_tokens
=
metadata_dict
.
pop
(
"num_decode_tokens"
)
batch_type
=
metadata_dict
.
pop
(
"batch_type"
)
# Create an attention metadata.
prefill_attn_metadata
=
None
decode_attn_metadata
=
None
if
batch_type
==
BatchType
.
PREFILL
or
batch_type
==
BatchType
.
MIXED
:
prefill_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
else
:
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if
batch_type
==
BatchType
.
MIXED
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
def
_prepare_pooling
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
PoolingMetadata
:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
pooling_params
=
seq_group_metadata
.
pooling_params
seq_groups
.
append
((
seq_ids
,
pooling_params
))
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
)
return
pooling_metadata
vllm/worker/model_runner.py
View file @
e254497b
import
time
from
enum
import
IntEnum
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -287,18 +287,18 @@ class ModelRunner:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq_len
-
context_len
)
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
seq_len
-
context_len
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
lora_prompt_mapping
.
extend
([
lora_id
]
*
(
seq_len
-
context_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
if
seq_group_metadata
.
block_tables
is
None
:
if
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
)
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
...
...
@@ -813,7 +813,6 @@ class ModelRunner:
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
...
...
@@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
prompt_tokens
=
[
0
]
*
seq_len
fake_image_input
=
None
return
SequenceData
(
prompt_tokens
),
fake_image_input
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if
block_tables
is
None
:
return
True
if
isinstance
(
block_tables
,
dict
)
and
all
(
value
is
None
for
value
in
block_tables
.
values
()):
return
True
return
False
vllm/worker/worker.py
View file @
e254497b
"""A GPU worker class."""
import
gc
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch.distributed
...
...
@@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar
)
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker_base
import
WorkerBase
...
...
@@ -68,7 +69,9 @@ class Worker(WorkerBase):
assert
not
self
.
lora_config
,
(
"To be tested: vision language model with LoRA settings."
)
self
.
model_runner
=
ModelRunner
(
ModelRunnerClass
=
(
EmbeddingModelRunner
if
self
.
model_config
.
embedding_mode
else
ModelRunner
)
self
.
model_runner
=
ModelRunnerClass
(
model_config
,
parallel_config
,
scheduler_config
,
...
...
@@ -83,7 +86,8 @@ class Worker(WorkerBase):
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self
.
cache_engine
:
CacheEngine
self
.
gpu_cache
:
List
[
torch
.
Tensor
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
torch
.
tensor
]]
=
None
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
...
@@ -209,7 +213,7 @@ class Worker(WorkerBase):
def
execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
if
execute_model_req
is
None
:
seq_group_metadata_list
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment