Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e254497b
Unverified
Commit
e254497b
authored
May 11, 2024
by
Chang Su
Committed by
GitHub
May 11, 2024
Browse files
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
parent
4e121310
Changes
38
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
1028 additions
and
70 deletions
+1028
-70
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+121
-29
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+18
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+35
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+134
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+14
-2
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+5
-5
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+56
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-3
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+11
-1
vllm/model_executor/models/llama_embedding.py
vllm/model_executor/models/llama_embedding.py
+87
-0
vllm/model_executor/pooling_metadata.py
vllm/model_executor/pooling_metadata.py
+69
-0
vllm/outputs.py
vllm/outputs.py
+81
-1
vllm/pooling_params.py
vllm/pooling_params.py
+20
-0
vllm/sequence.py
vllm/sequence.py
+77
-12
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+3
-2
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+266
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+18
-7
vllm/worker/worker.py
vllm/worker/worker.py
+9
-5
No files found.
vllm/entrypoints/llm.py
View file @
e254497b
...
@@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...
@@ -6,13 +6,17 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
class
LLM
:
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
"""An LLM for generating texts from given prompts and sampling parameters.
...
@@ -164,8 +168,89 @@ class LLM:
...
@@ -164,8 +168,89 @@ class LLM:
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
Returns:
Returns:
A list of `RequestOutput` objects containing the generated
A list of `RequestOutput` objects containing the
completions in the same order as the input prompts.
generated completions in the same order as the input prompts.
"""
if
sampling_params
is
None
:
# Use default sampling params.
sampling_params
=
SamplingParams
()
requests_data
=
self
.
_validate_and_prepare_requests
(
prompts
,
sampling_params
,
prompt_token_ids
,
lora_request
,
multi_modal_data
,
)
# Add requests to the engine and run the engine
for
request_data
in
requests_data
:
self
.
_add_request
(
**
request_data
)
return
self
.
_run_engine
(
use_tqdm
)
def
encode
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
List
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
"""
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
requests_data
=
self
.
_validate_and_prepare_requests
(
prompts
,
pooling_params
,
prompt_token_ids
,
lora_request
,
multi_modal_data
,
)
# Add requests to the engine and run the engine
for
request_data
in
requests_data
:
self
.
_add_request
(
**
request_data
)
return
self
.
_run_engine
(
use_tqdm
)
def
_validate_and_prepare_requests
(
self
,
prompts
:
Optional
[
Union
[
str
,
List
[
str
]]],
params
:
Union
[
Union
[
SamplingParams
,
PoolingParams
],
List
[
Union
[
SamplingParams
,
PoolingParams
]]],
# Unified parameter
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
)
->
List
[
dict
]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of
dictionaries with request data for further processing.
"""
"""
if
prompts
is
None
and
prompt_token_ids
is
None
:
if
prompts
is
None
and
prompt_token_ids
is
None
:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
...
@@ -188,40 +273,43 @@ class LLM:
...
@@ -188,40 +273,43 @@ class LLM:
assert
prompt_token_ids
is
not
None
assert
prompt_token_ids
is
not
None
num_requests
=
len
(
prompt_token_ids
)
num_requests
=
len
(
prompt_token_ids
)
if
sampling_params
is
None
:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
# Use default sampling params.
raise
ValueError
(
"The lengths of prompts and params "
sampling_params
=
SamplingParams
()
elif
isinstance
(
sampling_params
,
list
)
and
len
(
sampling_params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and sampling_params "
"must be the same."
)
"must be the same."
)
if
multi_modal_data
:
if
multi_modal_data
:
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
multi_modal_data
.
data
=
multi_modal_data
.
data
.
to
(
torch
.
float16
)
# Add requests to the engine.
# Add requests to the engine.
requests_data
=
[]
for
i
in
range
(
num_requests
):
for
i
in
range
(
num_requests
):
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
prompt
=
prompts
[
i
]
if
prompts
is
not
None
else
None
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
token_ids
=
None
if
prompt_token_ids
is
None
else
prompt_token_ids
[
i
]
i
]
self
.
_add_request
(
multi_modal_item
=
MultiModalData
(
type
=
multi_modal_data
.
type
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
),
)
if
multi_modal_data
else
None
requests_data
.
append
({
"prompt"
:
prompt
,
prompt
,
sampling_params
[
i
]
"params"
:
if
isinstance
(
sampling_params
,
list
)
else
sampling_params
,
params
[
i
]
if
isinstance
(
params
,
list
)
else
params
,
"prompt_token_ids"
:
token_ids
,
token_ids
,
lora_request
=
lora_request
,
"lora_request"
:
# Get ith image while maintaining the batch dim.
lora_request
,
multi_modal_data
=
MultiModalData
(
"multi_modal_data"
:
type
=
multi_modal_data
.
type
,
multi_modal_item
,
data
=
multi_modal_data
.
data
[
i
].
unsqueeze
(
0
))
})
if
multi_modal_data
else
None
,
)
return
requests_data
return
self
.
_run_engine
(
use_tqdm
)
def
_add_request
(
def
_add_request
(
self
,
self
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
s
ampling
_p
arams
:
Samp
lingParams
,
params
:
Union
[
S
ampling
P
arams
,
Poo
lingParams
]
,
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_token_ids
:
Optional
[
List
[
int
]],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
...
@@ -229,12 +317,14 @@ class LLM:
...
@@ -229,12 +317,14 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
self
.
llm_engine
.
add_request
(
request_id
,
prompt
,
prompt
,
sampling_
params
,
params
,
prompt_token_ids
,
prompt_token_ids
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
)
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
RequestOutput
]:
def
_run_engine
(
self
,
use_tqdm
:
bool
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
# Initialize tqdm.
# Initialize tqdm.
if
use_tqdm
:
if
use_tqdm
:
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
num_requests
=
self
.
llm_engine
.
get_num_unfinished_requests
()
...
@@ -245,7 +335,7 @@ class LLM:
...
@@ -245,7 +335,7 @@ class LLM:
postfix
=
f
"Generation Speed:
{
0
:.
2
f
}
toks/s"
,
postfix
=
f
"Generation Speed:
{
0
:.
2
f
}
toks/s"
,
)
)
# Run the engine.
# Run the engine.
outputs
:
List
[
RequestOutput
]
=
[]
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]
]
=
[]
total_toks
=
0
total_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
step_outputs
=
self
.
llm_engine
.
step
()
...
@@ -253,8 +343,10 @@ class LLM:
...
@@ -253,8 +343,10 @@ class LLM:
if
output
.
finished
:
if
output
.
finished
:
outputs
.
append
(
output
)
outputs
.
append
(
output
)
if
use_tqdm
:
if
use_tqdm
:
total_toks
+=
(
sum
(
if
isinstance
(
output
,
RequestOutput
):
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
))
# Calculate tokens only for RequestOutput
total_toks
+=
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
)
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
pbar
.
update
(
1
)
pbar
.
update
(
1
)
...
...
vllm/entrypoints/openai/api_server.py
View file @
e254497b
...
@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -22,9 +22,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
ChatCompletionResponse
,
CompletionRequest
,
ErrorResponse
)
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
)
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
...
@@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
...
@@ -32,6 +34,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat
:
OpenAIServingChat
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
...
@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
JSONResponse
(
content
=
generator
.
model_dump
())
@
app
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
return
JSONResponse
(
content
=
generator
.
model_dump
())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
args
=
parse_args
()
args
=
parse_args
()
...
@@ -190,7 +205,8 @@ if __name__ == "__main__":
...
@@ -190,7 +205,8 @@ if __name__ == "__main__":
args
.
chat_template
)
args
.
chat_template
)
openai_serving_completion
=
OpenAIServingCompletion
(
openai_serving_completion
=
OpenAIServingCompletion
(
engine
,
model_config
,
served_model_names
,
args
.
lora_modules
)
engine
,
model_config
,
served_model_names
,
args
.
lora_modules
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
)
app
.
root_path
=
args
.
root_path
app
.
root_path
=
args
.
root_path
uvicorn
.
run
(
app
,
uvicorn
.
run
(
app
,
host
=
args
.
host
,
host
=
args
.
host
,
...
...
vllm/entrypoints/openai/protocol.py
View file @
e254497b
# Adapted from
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
time
import
time
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
openai.types.chat
import
ChatCompletionMessageParam
from
openai.types.chat
import
ChatCompletionMessageParam
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -363,6 +364,24 @@ class CompletionRequest(OpenAIBaseModel):
return
data
return
data
class
EmbeddingRequest
(
BaseModel
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model
:
str
input
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
encoding_format
:
Optional
[
str
]
=
Field
(
'float'
,
pattern
=
'^(float|base64)$'
)
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
# doc: begin-embedding-pooling-params
additional_data
:
Optional
[
Any
]
=
None
# doc: end-embedding-pooling-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
class
LogProbs
(
OpenAIBaseModel
):
class
LogProbs
(
OpenAIBaseModel
):
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
text_offset
:
List
[
int
]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
token_logprobs
:
List
[
Optional
[
float
]]
=
Field
(
default_factory
=
list
)
...
@@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
...
@@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
usage
:
Optional
[
UsageInfo
]
=
Field
(
default
=
None
)
class
EmbeddingResponseData
(
BaseModel
):
index
:
int
object
:
str
=
"embedding"
embedding
:
List
[
float
]
class
EmbeddingResponse
(
BaseModel
):
id
:
str
=
Field
(
default_factory
=
lambda
:
f
"cmpl-
{
random_uuid
()
}
"
)
object
:
str
=
"list"
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
data
:
List
[
EmbeddingResponseData
]
usage
:
UsageInfo
class
ChatMessage
(
OpenAIBaseModel
):
class
ChatMessage
(
OpenAIBaseModel
):
role
:
str
role
:
str
content
:
str
content
:
str
...
...
vllm/entrypoints/openai/serving_embedding.py
0 → 100644
View file @
e254497b
import
time
from
typing
import
AsyncIterator
,
List
,
Tuple
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_completion
import
parse_prompt_format
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
logger
=
init_logger
(
__name__
)
TypeTokenIDs
=
List
[
int
]
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
EmbeddingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
EmbeddingResponse
:
data
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
final_res
.
outputs
.
embedding
)
data
.
append
(
embedding_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
data
,
usage
=
usage
,
)
class
OpenAIServingEmbedding
(
OpenAIServing
):
def
__init__
(
self
,
engine
:
AsyncLLMEngine
,
model_config
:
ModelConfig
,
served_model_names
:
List
[
str
]):
super
().
__init__
(
engine
=
engine
,
model_config
=
model_config
,
served_model_names
=
served_model_names
,
lora_modules
=
None
)
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
raw_request
:
Request
):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
# Return error for unsupported features.
if
request
.
encoding_format
==
"base64"
:
return
self
.
create_error_response
(
"base64 encoding is not currently supported"
)
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
model_name
=
request
.
model
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
# Schedule the request and get the result generator.
generators
=
[]
try
:
prompt_is_tokens
,
prompts
=
parse_prompt_format
(
request
.
input
)
pooling_params
=
request
.
to_pooling_params
()
for
i
,
prompt
in
enumerate
(
prompts
):
if
prompt_is_tokens
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
prompt
)
else
:
prompt_formats
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
prompt
)
prompt_ids
,
prompt_text
=
prompt_formats
generators
.
append
(
self
.
engine
.
generate
(
prompt_text
,
pooling_params
,
f
"
{
request_id
}
-
{
i
}
"
,
prompt_token_ids
=
prompt_ids
))
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
result_generator
:
AsyncIterator
[
Tuple
[
int
,
EmbeddingRequestOutput
]]
=
merge_async_iterators
(
*
generators
)
# Non-streaming response
final_res_batch
:
EmbeddingRequestOutput
=
[
None
]
*
len
(
prompts
)
async
for
i
,
res
in
result_generator
:
if
await
raw_request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
self
.
engine
.
abort
(
f
"
{
request_id
}
-
{
i
}
"
)
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
request_output_to_embedding_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
)
return
response
def
_check_embedding_mode
(
self
,
embedding_mode
:
bool
):
if
not
embedding_mode
:
logger
.
warning
(
"embedding_mode is False. Embedding API will not work."
)
else
:
logger
.
info
(
"Activating the server engine with embedding enabled."
)
vllm/entrypoints/openai/serving_engine.py
View file @
e254497b
...
@@ -9,7 +9,8 @@ from typing_extensions import Annotated
...
@@ -9,7 +9,8 @@ from typing_extensions import Annotated
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
ErrorResponse
,
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
,
LogProbs
,
ModelCard
,
ModelList
,
LogProbs
,
ModelCard
,
ModelList
,
ModelPermission
)
ModelPermission
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -165,7 +166,8 @@ class OpenAIServing:
...
@@ -165,7 +166,8 @@ class OpenAIServing:
def
_validate_prompt_and_tokenize
(
def
_validate_prompt_and_tokenize
(
self
,
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
],
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
...
@@ -191,6 +193,16 @@ class OpenAIServing:
...
@@ -191,6 +193,16 @@ class OpenAIServing:
prompt_ids
)
prompt_ids
)
token_num
=
len
(
input_ids
)
token_num
=
len
(
input_ids
)
# Note: EmbeddingRequest doesn't have max_tokens
if
isinstance
(
request
,
EmbeddingRequest
):
if
token_num
>
self
.
max_model_len
:
raise
ValueError
(
f
"This model's maximum context length is "
f
"
{
self
.
max_model_len
}
tokens. However, you requested "
f
"
{
token_num
}
tokens in the input for embedding "
f
"generation. Please reduce the length of the input."
,
)
return
input_ids
,
input_text
if
request
.
max_tokens
is
None
:
if
request
.
max_tokens
is
None
:
if
token_num
>=
self
.
max_model_len
:
if
token_num
>=
self
.
max_model_len
:
raise
ValueError
(
raise
ValueError
(
...
...
vllm/executor/gpu_executor.py
View file @
e254497b
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
@@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
...
@@ -123,8 +123,8 @@ class GPUExecutor(ExecutorBase):
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
driver_worker
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
output
=
self
.
driver_worker
.
execute_model
(
execute_model_req
)
return
output
return
output
...
@@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
...
@@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
async
def
execute_model_async
(
async
def
execute_model_async
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
execute_model_req
=
execute_model_req
,
)
)(
execute_model_req
=
execute_model_req
,
)
return
output
return
output
vllm/model_executor/layers/pooler.py
0 → 100644
View file @
e254497b
from
enum
import
IntEnum
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingTensors
)
from
vllm.sequence
import
EmbeddingSequenceGroupOutput
,
PoolerOutput
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
LAST
=
0
class
Pooler
(
nn
.
Module
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
normalize: Whether to normalize the pooled data.
"""
def
__init__
(
self
,
pooling_type
:
PoolingType
,
normalize
:
bool
):
super
().
__init__
()
self
.
pooling_type
=
pooling_type
self
.
normalize
=
normalize
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""Pools specific information from hidden states based on metadata."""
prompt_lens
=
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
if
self
.
pooling_type
==
PoolingType
.
LAST
:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
pooled_data
=
hidden_states
[
last_token_flat_indices
]
else
:
raise
ValueError
(
f
"Invalid pooling type:
{
self
.
pooling_type
}
"
)
if
self
.
normalize
:
pooled_data
=
nn
.
functional
.
normalize
(
pooled_data
,
p
=
2
,
dim
=
1
)
pooled_outputs
=
[
EmbeddingSequenceGroupOutput
(
data
.
tolist
())
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
vllm/model_executor/layers/sampler.py
View file @
e254497b
...
@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
...
@@ -10,8 +10,9 @@ from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors
,
SamplingTensors
,
SequenceGroupToSample
)
SequenceGroupToSample
)
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
(
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupOutput
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceOutput
)
# (num_token_ids, num_parent_ids) per sequence group.
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
SampleResultType
=
List
[
Tuple
[
List
[
int
],
List
[
int
]]]
...
@@ -1019,7 +1020,7 @@ def _build_sampler_output(
...
@@ -1019,7 +1020,7 @@ def _build_sampler_output(
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
sampler_output
.
append
(
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
Completion
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
# If not specified, store None values in SamplerOutput.
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
if
on_device_tensors
is
not
None
:
...
...
vllm/model_executor/models/__init__.py
View file @
e254497b
...
@@ -9,7 +9,7 @@ from vllm.utils import is_hip
...
@@ -9,7 +9,7 @@ from vllm.utils import is_hip
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Architecture -> (module, class).
# Architecture -> (module, class).
_MODELS
=
{
_GENERATION
_MODELS
=
{
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaModel"
:
(
"llama"
,
"LlamaForCausalLM"
),
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"AquilaForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
# AquilaChat2
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
...
@@ -58,6 +58,12 @@ _MODELS = {
...
@@ -58,6 +58,12 @@ _MODELS = {
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
"XverseForCausalLM"
:
(
"xverse"
,
"XverseForCausalLM"
),
}
}
_EMBEDDING_MODELS
=
{
"MistralModel"
:
(
"llama_embedding"
,
"LlamaEmbeddingModel"
),
}
_MODELS
=
{
**
_GENERATION_MODELS
,
**
_EMBEDDING_MODELS
}
# Architecture -> type.
# Architecture -> type.
# out of tree models
# out of tree models
_OOT_MODELS
:
Dict
[
str
,
Type
[
nn
.
Module
]]
=
{}
_OOT_MODELS
:
Dict
[
str
,
Type
[
nn
.
Module
]]
=
{}
...
@@ -114,6 +120,10 @@ class ModelRegistry:
...
@@ -114,6 +120,10 @@ class ModelRegistry:
global
_OOT_MODELS
global
_OOT_MODELS
_OOT_MODELS
[
model_arch
]
=
model_cls
_OOT_MODELS
[
model_arch
]
=
model_cls
@
staticmethod
def
is_embedding_model
(
model_arch
:
str
)
->
bool
:
return
model_arch
in
_EMBEDDING_MODELS
__all__
=
[
__all__
=
[
"ModelRegistry"
,
"ModelRegistry"
,
...
...
vllm/model_executor/models/llama_embedding.py
0 → 100644
View file @
e254497b
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
PoolerOutput
class
LlamaEmbeddingModel
(
nn
.
Module
):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of LlamaModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
**
kwargs
,
)
->
None
:
super
().
__init__
()
self
.
model
=
LlamaModel
(
**
kwargs
)
self
.
_pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
model
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
model
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/pooling_metadata.py
0 → 100644
View file @
e254497b
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Tuple
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils
import
is_pin_memory_available
class
PoolingMetadata
:
"""Metadata for pooling operations in the Pooler layer.
This class holds the necessary information for pooling operations,
providing context for how to perform pooling and other related operations.
Attributes:
seq_groups: List of (seq_ids, pooling_params).
seq_data: A mapping of sequence ID to additional sequence data.
prompt_lens: List of the lengths of each prompt.
"""
def
__init__
(
self
,
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]],
seq_data
:
Dict
[
int
,
Any
],
# Specific data related to sequences
prompt_lens
:
List
[
int
],
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_data
=
seq_data
self
.
prompt_lens
=
prompt_lens
def
__repr__
(
self
)
->
str
:
return
(
"PoolingMetadata("
f
"seq_groups=
{
self
.
seq_groups
}
, "
f
"seq_data=
{
self
.
seq_data
}
, "
f
"prompt_lens=
{
self
.
prompt_lens
}
)"
)
@
dataclass
class
PoolingTensors
:
"""Tensors for pooling."""
prompt_lens
:
torch
.
Tensor
@
classmethod
def
from_pooling_metadata
(
cls
,
pooling_metadata
:
"PoolingMetadata"
,
device
:
torch
.
device
,
)
->
"PoolingTensors"
:
"""
Create PoolingTensors from PoolingMetadata.
Args:
pooling_metadata: PoolingMetadata instance to convert.
device: Device to store the tensors.
"""
# Convert prompt lengths to tensor
pin_memory
=
is_pin_memory_available
()
prompt_lens_t
=
torch
.
tensor
(
pooling_metadata
.
prompt_lens
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
return
cls
(
prompt_lens
=
prompt_lens_t
.
to
(
device
=
device
,
non_blocking
=
True
),
)
vllm/outputs.py
View file @
e254497b
...
@@ -57,8 +57,27 @@ class CompletionOutput:
...
@@ -57,8 +57,27 @@ class CompletionOutput:
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
class
EmbeddingOutput
:
"""The output data of one completion output of a request.
Args:
embedding: The embedding vector, which is a list of floats. The
length of vector depends on the model as listed in the embedding guide.
"""
def
__init__
(
self
,
embedding
:
List
[
float
],
)
->
None
:
self
.
embedding
=
embedding
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingOutput("
f
"embedding=
{
len
(
self
.
embedding
)
}
"
)
class
RequestOutput
:
class
RequestOutput
:
"""The output data of a request to the LLM.
"""The output data of a
completion
request to the LLM.
Args:
Args:
request_id: The unique ID of the request.
request_id: The unique ID of the request.
...
@@ -93,6 +112,9 @@ class RequestOutput:
...
@@ -93,6 +112,9 @@ class RequestOutput:
@
classmethod
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
if
seq_group
.
sampling_params
is
None
:
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
top_n_seqs
=
seqs
...
@@ -148,3 +170,61 @@ class RequestOutput:
...
@@ -148,3 +170,61 @@ class RequestOutput:
f
"finished=
{
self
.
finished
}
, "
f
"finished=
{
self
.
finished
}
, "
f
"metrics=
{
self
.
metrics
}
, "
f
"metrics=
{
self
.
metrics
}
, "
f
"lora_request=
{
self
.
lora_request
}
)"
)
f
"lora_request=
{
self
.
lora_request
}
)"
)
class
EmbeddingRequestOutput
:
"""
The output data of an embedding request to the LLM.
Args:
request_id (str): A unique identifier for the embedding request.
outputs (EmbeddingOutput): The embedding results for the given input.
prompt_token_ids (List[int]): A list of token IDs used in the prompt.
finished (bool): A flag indicating whether the embedding is completed.
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
'EmbeddingOutput'
,
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
finished
=
finished
self
.
outputs
=
outputs
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
'SequenceGroup'
)
->
"EmbeddingRequestOutput"
:
if
seq_group
.
embeddings
is
None
:
raise
ValueError
(
"Embeddings are missing in seq_group for EmbeddingRequest."
)
output
=
EmbeddingOutput
(
seq_group
.
embeddings
)
prompt_token_ids
=
seq_group
.
prompt_token_ids
finished
=
seq_group
.
is_finished
()
return
cls
(
seq_group
.
request_id
,
output
,
prompt_token_ids
,
finished
)
def
__repr__
(
self
):
"""
Returns a string representation of an EmbeddingRequestOutput instance.
The representation includes the request_id and the number of outputs,
providing a quick overview of the embedding request's results.
Returns:
str: A string representation of the EmbeddingRequestOutput instance.
"""
return
(
f
"EmbeddingRequestOutput(request_id='
{
self
.
request_id
}
', "
f
"outputs=
{
repr
(
self
.
outputs
)
}
, "
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"finished=
{
self
.
finished
}
)"
)
class
RequestOutputFactory
:
@
staticmethod
def
create
(
seq_group
):
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
)
vllm/pooling_params.py
0 → 100644
View file @
e254497b
from
typing
import
Any
,
Optional
class
PoolingParams
:
"""Pooling parameters for pooling.
Attributes:
additional_data: Any additional data needed for pooling.
"""
def
__init__
(
self
,
additional_data
:
Optional
[
Any
]
=
None
):
self
.
additional_data
=
additional_data
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
additional_data
=
self
.
additional_data
,
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
vllm/sequence.py
View file @
e254497b
"""Sequence and its related classes."""
"""Sequence and its related classes."""
import
copy
import
copy
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
vllm.block
import
LogicalTokenBlock
from
vllm.block
import
LogicalTokenBlock
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -402,16 +404,22 @@ class SequenceGroup:
...
@@ -402,16 +404,22 @@ class SequenceGroup:
arrival_time: The arrival time of the request.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
for an embedding model.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
seqs
:
List
[
Sequence
],
seqs
:
List
[
Sequence
],
sampling_params
:
SamplingParams
,
arrival_time
:
float
,
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
...
@@ -425,6 +433,8 @@ class SequenceGroup:
...
@@ -425,6 +433,8 @@ class SequenceGroup:
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
state
=
SequenceGroupState
()
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
@
property
@
property
def
prompt
(
self
)
->
str
:
def
prompt
(
self
)
->
str
:
...
@@ -479,12 +489,13 @@ class SequenceGroup:
...
@@ -479,12 +489,13 @@ class SequenceGroup:
def
get_max_num_running_seqs
(
self
)
->
int
:
def
get_max_num_running_seqs
(
self
)
->
int
:
"""The maximum number of sequences running in parallel in the remaining
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
lifetime of the request."""
if
self
.
sampling_params
.
use_beam_search
:
if
self
.
sampling_params
and
self
.
sampling_params
.
use_beam_search
:
# For beam search, maximally there will always be `best_of` beam
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
# candidates running in the future.
return
self
.
sampling_params
.
best_of
return
self
.
sampling_params
.
best_of
else
:
else
:
if
self
.
sampling_params
.
best_of
>
self
.
num_seqs
():
if
(
self
.
sampling_params
and
self
.
sampling_params
.
best_of
>
self
.
num_seqs
()):
# At prompt stage, the sequence group is not yet filled up
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
# generation stage, we will have `best_of` sequences running.
...
@@ -555,7 +566,7 @@ class SequenceGroup:
...
@@ -555,7 +566,7 @@ class SequenceGroup:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_seqs
())
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
# Every sequence
s
should be in the same stage.
# Every sequence should be in the same stage.
return
self
.
get_seqs
()[
0
].
is_prefill
()
return
self
.
get_seqs
()[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
...
@@ -594,6 +605,7 @@ class SequenceGroupMetadata:
...
@@ -594,6 +605,7 @@ class SequenceGroupMetadata:
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
block_tables
:
Dict
[
int
,
List
[
int
]],
block_tables
:
Dict
[
int
,
List
[
int
]],
do_sample
:
bool
=
True
,
do_sample
:
bool
=
True
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -605,6 +617,7 @@ class SequenceGroupMetadata:
...
@@ -605,6 +617,7 @@ class SequenceGroupMetadata:
self
.
seq_data
=
seq_data
self
.
seq_data
=
seq_data
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
self
.
block_tables
=
block_tables
self
.
block_tables
=
block_tables
self
.
pooling_params
=
pooling_params
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
...
@@ -669,8 +682,20 @@ class SequenceOutput:
...
@@ -669,8 +682,20 @@ class SequenceOutput:
return
equal
and
log_probs_equal
return
equal
and
log_probs_equal
class
SequenceGroupOutput
:
class
SequenceGroupOutput
(
ABC
):
"""The model output associated with a sequence group."""
"""The base class for model outputs associated with a sequence group."""
@
abstractmethod
def
__repr__
(
self
)
->
str
:
pass
@
abstractmethod
def
__eq__
(
self
,
other
:
object
)
->
bool
:
pass
class
CompletionSequenceGroupOutput
(
SequenceGroupOutput
):
"""The model output associated with a completion sequence group."""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -682,26 +707,45 @@ class SequenceGroupOutput:
...
@@ -682,26 +707,45 @@ class SequenceGroupOutput:
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroupOutput(samples=
{
self
.
samples
}
, "
return
(
f
"
Completion
SequenceGroupOutput(samples=
{
self
.
samples
}
, "
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
)"
)
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
SequenceGroupOutput
):
if
not
isinstance
(
other
,
Completion
SequenceGroupOutput
):
raise
NotImplementedError
()
raise
NotImplementedError
()
return
(
self
.
samples
==
other
.
samples
return
(
self
.
samples
==
other
.
samples
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
class
EmbeddingSequenceGroupOutput
(
SequenceGroupOutput
):
"""The model output associated with an embedding sequence group."""
def
__init__
(
self
,
embeddings
:
List
[
float
],
)
->
None
:
self
.
embeddings
=
embeddings
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingSequenceGroupOutput("
f
"embeddings_shape=
{
len
(
self
.
embeddings
)
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
EmbeddingSequenceGroupOutput
):
raise
NotImplementedError
()
return
self
.
embeddings
==
other
.
embeddings
@
dataclass
@
dataclass
class
SamplerOutput
:
class
SamplerOutput
:
"""For each sequence group, we generate a list of SequenceOutput object,
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
each of which contains one possible candidate for the next token.
This datastructure implements methods so it can be used like a list, but
This data
structure implements methods
,
so it can be used like a list, but
also has optional fields for device tensors.
also has optional fields for device tensors.
"""
"""
outputs
:
List
[
SequenceGroupOutput
]
outputs
:
List
[
Completion
SequenceGroupOutput
]
# On-device tensor containing probabilities of each token.
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
sampled_token_probs
:
Optional
[
"torch.Tensor"
]
=
None
...
@@ -742,6 +786,27 @@ class SamplerOutput:
...
@@ -742,6 +786,27 @@ class SamplerOutput:
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@
dataclass
class
PoolerOutput
:
"""The output from a pooling operation in the embedding model."""
outputs
:
List
[
EmbeddingSequenceGroupOutput
]
spec_decode_worker_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
return
len
(
self
.
outputs
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
@
dataclass
@
dataclass
class
ExecuteModelRequest
:
class
ExecuteModelRequest
:
"""The model execution request."""
"""The model execution request."""
...
...
vllm/spec_decode/util.py
View file @
e254497b
...
@@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
...
@@ -4,7 +4,8 @@ from typing import Dict, List, Tuple
import
torch
import
torch
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceGroupOutput
,
SequenceOutput
)
SeqId
=
int
SeqId
=
int
...
@@ -94,7 +95,7 @@ def create_sequence_group_output(
...
@@ -94,7 +95,7 @@ def create_sequence_group_output(
for
topk_logprob_index
,
_
in
enumerate
(
topk_token_ids
)
for
topk_logprob_index
,
_
in
enumerate
(
topk_token_ids
)
})
})
return
SequenceGroupOutput
(
return
Completion
SequenceGroupOutput
(
samples
=
[
samples
=
[
SequenceOutput
(
parent_seq_id
=
seq_id
,
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
output_token
=
token_id
,
...
...
vllm/worker/embedding_model_runner.py
0 → 100644
View file @
e254497b
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
BatchType
,
ModelRunner
logger
=
init_logger
(
__name__
)
class
EmbeddingModelRunner
(
ModelRunner
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
):
super
().
__init__
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
lora_config
=
lora_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
vision_language_config
=
vision_language_config
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
PoolerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
if
self
.
lora_config
:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
# Currently cuda graph is only supported by the decode phase.
prefill_meta
=
attn_metadata
.
prefill_metadata
decode_meta
=
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
graph_batch_size
=
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
model_executable
=
self
.
model
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
attn_metadata
,
}
if
self
.
vision_language_config
:
execute_model_kwargs
.
update
({
"image_input"
:
multi_modal_input
})
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
return
self
.
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
)
def
prepare_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
PoolingMetadata
,
Set
[
LoRARequest
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
prefill_reqs
=
[]
decode_reqs
=
[]
for
seq_group_meta
in
seq_group_metadata_list
:
if
seq_group_meta
.
is_prompt
:
prefill_reqs
.
append
(
seq_group_meta
)
else
:
decode_reqs
.
append
(
seq_group_meta
)
# Prepare input tensors.
(
input_tokens
,
input_positions
,
prefill_attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
,
multi_modal_input
,
slot_mapping
,
)
=
self
.
_prepare_prompt
(
prefill_reqs
)
(
decode_input_tokens
,
decode_input_positions
,
decode_attn_metadata
,
decode_lora_index_mapping
,
decode_lora_prompt_mapping
,
decode_lora_requests
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
# Prepare PoolingMetadata
pooling_metadata
=
self
.
_prepare_pooling
(
seq_group_metadata_list
,
prompt_lens
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens
.
extend
(
decode_input_tokens
)
input_positions
.
extend
(
decode_input_positions
)
slot_mapping
.
extend
(
decode_slot_mapping
)
lora_index_mapping
.
extend
(
decode_lora_index_mapping
)
lora_prompt_mapping
.
extend
(
decode_lora_prompt_mapping
)
lora_requests
.
update
(
decode_lora_requests
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if
(
prefill_attn_metadata
is
not
None
and
decode_attn_metadata
is
not
None
):
batch_type
=
BatchType
.
MIXED
elif
prefill_attn_metadata
is
not
None
:
batch_type
=
BatchType
.
PREFILL
else
:
batch_type
=
BatchType
.
DECODE
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_input"
:
multi_modal_input
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"num_prefills"
:
num_prefills
,
"batch_type"
:
batch_type
,
}
if
prefill_attn_metadata
is
not
None
:
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
else
:
assert
decode_attn_metadata
is
not
None
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if
batch_type
==
BatchType
.
MIXED
:
assert
decode_attn_metadata
is
not
None
metadata_dict
=
decode_attn_metadata
.
asdict_zerocopy
()
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
slot_mapping
=
metadata_dict
.
pop
(
"slot_mapping"
)
num_prefills
=
metadata_dict
.
pop
(
"num_prefills"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
num_prefill_tokens
=
metadata_dict
.
pop
(
"num_prefill_tokens"
)
num_decode_tokens
=
metadata_dict
.
pop
(
"num_decode_tokens"
)
batch_type
=
metadata_dict
.
pop
(
"batch_type"
)
# Create an attention metadata.
prefill_attn_metadata
=
None
decode_attn_metadata
=
None
if
batch_type
==
BatchType
.
PREFILL
or
batch_type
==
BatchType
.
MIXED
:
prefill_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
else
:
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
None
,
seq_data
=
None
,
prompt_lens
=
None
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if
batch_type
==
BatchType
.
MIXED
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
pooling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
def
_prepare_pooling
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
PoolingMetadata
:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
pooling_params
=
seq_group_metadata
.
pooling_params
seq_groups
.
append
((
seq_ids
,
pooling_params
))
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
pooling_metadata
=
PoolingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
)
return
pooling_metadata
vllm/worker/model_runner.py
View file @
e254497b
import
time
import
time
from
enum
import
IntEnum
from
enum
import
IntEnum
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -287,18 +287,18 @@ class ModelRunner:
...
@@ -287,18 +287,18 @@ class ModelRunner:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq_len
-
context_len
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq_len
-
context_len
)
lora_prompt_mapping
.
extend
(
lora_prompt_mapping
.
extend
([
lora_id
]
*
(
[
lora_id
]
*
seq_len
-
context_len
if
seq_group_metadata
.
sampling_params
(
seq_len
-
context_len
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
if
seq_group_metadata
.
multi_modal_data
:
if
seq_group_metadata
.
multi_modal_data
:
multi_modal_input_list
.
append
(
multi_modal_input_list
.
append
(
seq_group_metadata
.
multi_modal_data
.
data
)
seq_group_metadata
.
multi_modal_data
.
data
)
if
seq_group_metadata
.
block_tables
is
None
:
if
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
)
:
# During memory profiling, the block tables are not initialized
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# yet. In this case, we just use a dummy slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
continue
...
@@ -813,7 +813,6 @@ class ModelRunner:
...
@@ -813,7 +813,6 @@ class ModelRunner:
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
self
.
vocab_size
-
1
)
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
# consumption create dummy lora request copies from the lora request
...
@@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
...
@@ -1139,3 +1138,15 @@ def _prepare_fake_inputs(
prompt_tokens
=
[
0
]
*
seq_len
prompt_tokens
=
[
0
]
*
seq_len
fake_image_input
=
None
fake_image_input
=
None
return
SequenceData
(
prompt_tokens
),
fake_image_input
return
SequenceData
(
prompt_tokens
),
fake_image_input
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if
block_tables
is
None
:
return
True
if
isinstance
(
block_tables
,
dict
)
and
all
(
value
is
None
for
value
in
block_tables
.
values
()):
return
True
return
False
vllm/worker/worker.py
View file @
e254497b
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
...
@@ -16,8 +16,9 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar
)
init_custom_ar
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
...
@@ -68,7 +69,9 @@ class Worker(WorkerBase):
...
@@ -68,7 +69,9 @@ class Worker(WorkerBase):
assert
not
self
.
lora_config
,
(
assert
not
self
.
lora_config
,
(
"To be tested: vision language model with LoRA settings."
)
"To be tested: vision language model with LoRA settings."
)
self
.
model_runner
=
ModelRunner
(
ModelRunnerClass
=
(
EmbeddingModelRunner
if
self
.
model_config
.
embedding_mode
else
ModelRunner
)
self
.
model_runner
=
ModelRunnerClass
(
model_config
,
model_config
,
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
...
@@ -83,7 +86,8 @@ class Worker(WorkerBase):
...
@@ -83,7 +86,8 @@ class Worker(WorkerBase):
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
# initialize_cache.
self
.
cache_engine
:
CacheEngine
self
.
cache_engine
:
CacheEngine
self
.
gpu_cache
:
List
[
torch
.
Tensor
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
torch
.
tensor
]]
=
None
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
@@ -209,7 +213,7 @@ class Worker(WorkerBase):
...
@@ -209,7 +213,7 @@ class Worker(WorkerBase):
def
execute_model
(
def
execute_model
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
)
->
List
[
Union
[
SamplerOutput
,
PoolerOutput
]
]:
if
execute_model_req
is
None
:
if
execute_model_req
is
None
:
seq_group_metadata_list
=
None
seq_group_metadata_list
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment