Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
287 additions
and
160 deletions
+287
-160
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+47
-20
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+8
-9
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+33
-18
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+9
-8
vllm/envs.py
vllm/envs.py
+14
-3
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-0
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+0
-14
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+18
-28
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+4
-2
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+8
-2
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+6
-1
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+11
-5
vllm/inputs/data.py
vllm/inputs/data.py
+6
-0
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+15
-7
vllm/inputs/registry.py
vllm/inputs/registry.py
+74
-22
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+1
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+1
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+10
-6
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+11
-7
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+10
-6
No files found.
vllm/entrypoints/openai/serving_completion.py
View file @
539aa992
...
...
@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
ErrorResponse
,
UsageInfo
)
ErrorResponse
,
RequestResponseMetadata
,
UsageInfo
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
,
PromptAdapterPath
)
from
vllm.logger
import
init_logger
...
...
@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
,
...
...
@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing):
if
error_check_ret
is
not
None
:
return
error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if
self
.
engine_client
.
errored
:
raise
self
.
engine_client
.
dead_error
# Return error for unsupported features.
if
request
.
suffix
is
not
None
:
return
self
.
create_error_response
(
"suffix is not currently supported"
)
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
time
())
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
...
...
@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
...
...
@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
is_tracing_enabled
=
(
await
self
.
async_
engine_client
.
is_tracing_enabled
())
is_tracing_enabled
=
(
await
self
.
engine_client
.
is_tracing_enabled
())
trace_headers
=
None
if
is_tracing_enabled
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
...
@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
log_tracing_disabled_warning
()
generator
=
self
.
async_
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
request_id_item
,
...
...
@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if
stream
:
return
self
.
completion_stream_generator
(
request
,
result_generator
,
request_id
,
created_time
,
model_name
,
num_prompts
=
len
(
prompts
),
tokenizer
=
tokenizer
)
return
self
.
completion_stream_generator
(
request
,
result_generator
,
request_id
,
created_time
,
model_name
,
num_prompts
=
len
(
prompts
),
tokenizer
=
tokenizer
,
request_metadata
=
request_metadata
)
# Non-streaming response
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
...
...
@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
,
model_name
,
tokenizer
,
request_metadata
,
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
...
...
@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name
:
str
,
num_prompts
:
int
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_text_lens
=
[
0
]
*
num_choices
*
num_prompts
...
...
@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset
=
False
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens
=
sum
(
num_prompt_tokens
)
total_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
...
...
@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
:
int
,
model_name
:
str
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
CompletionResponse
:
choices
:
List
[
CompletionResponseChoice
]
=
[]
num_prompt_tokens
=
0
...
...
@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
request_metadata
.
final_usage_info
=
usage
return
CompletionResponse
(
id
=
request_id
,
created
=
created_time
,
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
539aa992
...
...
@@ -8,13 +8,13 @@ from fastapi import Request
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponseData
,
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingOutput
,
EmbeddingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
...
...
@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing):
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
request_logger
:
Optional
[
RequestLogger
],
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
...
...
@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
pooling_params
=
request
.
to_pooling_params
()
...
...
@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models"
)
generator
=
self
.
async_
engine_client
.
encode
(
generator
=
self
.
engine_client
.
encode
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
pooling_params
,
request_id_item
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
539aa992
...
...
@@ -8,7 +8,7 @@ from pydantic import Field
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
logger
=
init_logger
(
__name__
)
@
dataclass
class
BaseModelPath
:
name
:
str
model_path
:
str
@
dataclass
class
PromptAdapterPath
:
name
:
str
...
...
@@ -49,6 +55,7 @@ class PromptAdapterPath:
class
LoRAModulePath
:
name
:
str
path
:
str
base_model_name
:
Optional
[
str
]
=
None
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
...
...
@@ -64,9 +71,9 @@ class OpenAIServing:
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
...
...
@@ -75,21 +82,24 @@ class OpenAIServing:
):
super
().
__init__
()
self
.
async_
engine_client
=
async_
engine_client
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
se
rved
_model_
names
=
served
_model_
name
s
self
.
ba
se_model_
paths
=
base
_model_
path
s
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_requests
=
[]
if
lora_modules
is
not
None
:
self
.
lora_requests
=
[
LoRARequest
(
lora_name
=
lora
.
name
,
lora_int_id
=
i
,
lora_path
=
lora
.
path
,
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
LoRARequest
(
lora_name
=
lora
.
name
,
lora_int_id
=
i
,
lora_path
=
lora
.
path
,
base_model_name
=
lora
.
base_model_name
if
lora
.
base_model_name
and
self
.
_is_model_supported
(
lora
.
base_model_name
)
else
self
.
base_model_paths
[
0
].
name
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
]
self
.
prompt_adapter_requests
=
[]
...
...
@@ -112,21 +122,23 @@ class OpenAIServing:
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
model_cards
=
[
ModelCard
(
id
=
se
rved
_model
_
name
,
ModelCard
(
id
=
ba
se_model
.
name
,
max_model_len
=
self
.
max_model_len
,
root
=
self
.
served_model_names
[
0
]
,
root
=
base_model
.
model_path
,
permission
=
[
ModelPermission
()])
for
se
rved
_model
_name
in
self
.
se
rved
_model_
name
s
for
ba
se_model
in
self
.
ba
se_model_
path
s
]
lora_cards
=
[
ModelCard
(
id
=
lora
.
lora_name
,
root
=
self
.
served_model_names
[
0
],
root
=
lora
.
local_path
,
parent
=
lora
.
base_model_name
if
lora
.
base_model_name
else
self
.
base_model_paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
for
lora
in
self
.
lora_requests
]
prompt_adapter_cards
=
[
ModelCard
(
id
=
prompt_adapter
.
prompt_adapter_name
,
root
=
self
.
se
rved
_model_name
s
[
0
]
,
root
=
self
.
ba
se_model_
paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
for
prompt_adapter
in
self
.
prompt_adapter_requests
]
...
...
@@ -159,7 +171,7 @@ class OpenAIServing:
async
def
_guided_decode_logits_processor
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
decoding_config
=
await
self
.
async_
engine_client
.
get_decoding_config
()
decoding_config
=
await
self
.
engine_client
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
return
await
get_guided_decoding_logits_processor
(
...
...
@@ -169,7 +181,7 @@ class OpenAIServing:
self
,
request
:
AnyRequest
,
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
if
self
.
_is_model_supported
(
request
.
model
)
:
return
None
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
lora_requests
]:
return
None
...
...
@@ -187,7 +199,7 @@ class OpenAIServing:
self
,
request
:
AnyRequest
)
->
Union
[
Tuple
[
None
,
None
],
Tuple
[
LoRARequest
,
None
],
Tuple
[
None
,
PromptAdapterRequest
]]:
if
request
.
model
in
self
.
served_model_names
:
if
self
.
_is_model_supported
(
request
.
model
)
:
return
None
,
None
for
lora
in
self
.
lora_requests
:
if
request
.
model
==
lora
.
lora_name
:
...
...
@@ -480,3 +492,6 @@ class OpenAIServing:
if
lora_request
.
lora_name
!=
lora_name
]
return
f
"Success: LoRA adapter '
{
lora_name
}
' removed successfully."
def
_is_model_supported
(
self
,
model_name
):
return
any
(
model
.
name
==
model_name
for
model
in
self
.
base_model_paths
)
vllm/entrypoints/openai/serving_tokenization.py
View file @
539aa992
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
...
...
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
...
...
@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing):
def
__init__
(
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
...
...
@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
request
,
TokenizeChatRequest
):
...
...
@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
...
...
vllm/envs.py
View file @
539aa992
...
...
@@ -59,10 +59,12 @@ if TYPE_CHECKING:
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
:
int
=
5
000
VLLM_RPC_TIMEOUT
:
int
=
10
000
# ms
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH"
,
"0"
)
==
"1"
,
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
...
...
@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
,
"0"
)),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
...
...
@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
"
,
"
5
000"
)),
"VLLM_RPC_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_TIMEOUT"
,
"
10
000"
)),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
...
...
vllm/executor/cpu_executor.py
View file @
539aa992
...
...
@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
))
for
rank
in
range
(
1
,
world_size
)
]
self
.
worker_monitor
=
None
if
world_size
!=
1
or
is_async
:
if
is_async
:
async_worker_list
=
self
.
workers
+
[
self
.
driver_worker
]
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
539aa992
import
asyncio
import
os
import
signal
import
threading
import
weakref
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
...
...
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref
=
weakref
.
ref
(
self
)
def
shutdown
(
signum
,
frame
):
if
executor
:
=
ref
():
executor
.
shutdown
()
if
threading
.
current_thread
()
is
threading
.
main_thread
():
signal
.
signal
(
signal
.
SIGINT
,
shutdown
)
signal
.
signal
(
signal
.
SIGTERM
,
shutdown
)
self
.
driver_worker
=
self
.
_create_worker
(
distributed_init_method
=
distributed_init_method
)
self
.
_run_workers
(
"init_device"
)
...
...
vllm/executor/multiproc_worker_utils.py
View file @
539aa992
...
...
@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
super
().
__init__
(
daemon
=
True
)
self
.
result_queue
=
mp
.
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
...
...
@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread):
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
False
)
# super().__init__(daemon=True)
super
().
__init__
(
daemon
=
True
)
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
_close
=
False
...
...
@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread):
self
.
_close
=
True
# Kill / cleanup all workers
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if
not
sys
.
is_finalizing
():
# Kill / cleanup all workers
died_count
=
0
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
died_count
+=
1
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
if
died_count
<
len
(
self
.
workers
):
logger
.
info
(
"Killing remaining local vLLM worker processes"
)
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
# Cleanup any remaining workers
# logger.info("Killing local vLLM worker processes")
if
logger
:
logger
.
info
(
"Killing local vLLM worker processes"
)
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
...
...
@@ -184,6 +168,8 @@ class ProcessWorkerWrapper:
self
.
tasks
[
task_id
]
=
future
try
:
self
.
_task_queue
.
put
((
task_id
,
method
,
args
,
kwargs
))
except
SystemExit
:
raise
except
BaseException
as
e
:
del
self
.
tasks
[
task_id
]
raise
ChildProcessError
(
"worker died"
)
from
e
...
...
@@ -238,6 +224,10 @@ def _run_worker_process(
try
:
executor
=
getattr
(
worker
,
method
)
output
=
executor
(
*
args
,
**
kwargs
)
except
SystemExit
:
raise
except
KeyboardInterrupt
:
break
except
BaseException
as
e
:
tb
=
traceback
.
format_exc
()
logger
.
error
(
...
...
@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
\ No newline at end of file
vllm/executor/ray_gpu_executor.py
View file @
539aa992
...
...
@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
required_version
=
version
.
parse
(
"2.35"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
# TODO: update the constraint once we adapt to the backward
# incompatible API change from ray 2.36
if
current_version
!=
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
import
importlib.util
...
...
vllm/executor/ray_tpu_executor.py
View file @
539aa992
...
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class
RayTPUExecutor
(
TPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
self
,
*
args
,
**
kwargs
):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...
...
@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor):
)
assert
self
.
speculative_config
is
None
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
if
self
.
scheduler_config
.
is_multi_step
:
worker_module_name
=
"vllm.worker.multi_step_tpu_worker"
worker_class_name
=
"MultiStepTPUWorker"
else
:
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
...
...
vllm/executor/ray_utils.py
View file @
539aa992
...
...
@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800
try
:
import
ray
from
ray._private.state
import
available_resources_per_node
from
ray.util
import
placement_group_table
from
ray.util.placement_group
import
PlacementGroup
try
:
from
ray._private.state
import
available_resources_per_node
except
ImportError
:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from
ray._private.state
import
state
as
_state
available_resources_per_node
=
_state
.
_available_resources_per_node
class
RayWorkerWrapper
(
WorkerWrapperBase
):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
...
...
vllm/executor/tpu_executor.py
View file @
539aa992
...
...
@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
,
):
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
if
self
.
scheduler_config
.
is_multi_step
:
from
vllm.worker.multi_step_tpu_worker
import
MultiStepTPUWorker
worker
=
MultiStepTPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
else
:
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
def
initialize_cache
(
self
,
...
...
vllm/inputs/data.py
View file @
539aa992
...
...
@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available.
"""
encoder_multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
...
...
vllm/inputs/preprocess.py
View file @
539aa992
...
...
@@ -128,6 +128,7 @@ class InputPreprocessor:
def
_prepare_decoder_input_ids_for_generation
(
self
,
decoder_input_ids
:
Optional
[
List
[
int
]],
force_bos
:
bool
=
True
,
)
->
List
[
int
]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
...
...
@@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
if
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
if
force_bos
and
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
decoder_input_ids
=
[
decoder_start_token_id
]
+
decoder_input_ids
return
decoder_input_ids
...
...
@@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
if
encoder_mm_data
is
not
None
or
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modal encoder-decoder models are "
"not supported yet"
)
if
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet"
)
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
))
# For Multi-Modal models (e.g., mllama), the text input can be
# <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
,
force_bos
=
(
encoder_mm_data
is
None
and
decoder_mm_data
is
None
)))
return
EncoderDecoderLLMInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
multi_modal_data
=
decoder_mm_data
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
encoder_multi_modal_data
=
encoder_mm_data
,
)
def
_process_encoder_decoder_prompt
(
...
...
vllm/inputs/registry.py
View file @
539aa992
import
functools
from
array
import
array
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
...
...
@@ -10,6 +9,7 @@ from transformers import PretrainedConfig
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_allowed_kwarg_only_overrides
from
.data
import
LLMInputs
...
...
@@ -22,10 +22,6 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
class
InputContext
:
...
...
@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
**
mm_processor_kwargs
:
Any
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
"""
...
...
...
@@ -111,6 +112,8 @@ class InputRegistry:
def
__init__
(
self
)
->
None
:
self
.
_dummy_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
DummyDataFactory
]
=
{}
self
.
_dummy_encoder_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
DummyDataFactory
]
=
{}
self
.
_input_processors_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
InputProcessor
]
=
{}
...
...
@@ -130,8 +133,7 @@ class InputRegistry:
# Avoid circular import
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_seq_data
=
SequenceData
.
from_token_counts
((
0
,
seq_len
))
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
@@ -158,11 +160,48 @@ class InputRegistry:
return
wrapper
def
_get_dummy_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
return
self
.
_dummy_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
def
register_dummy_encoder_data
(
self
,
factory
:
DummyDataFactory
):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def
wrapper
(
model_cls
:
N
)
->
N
:
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
logger
.
warning
(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one."
,
model_cls
,
self
)
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
=
factory
return
model_cls
return
wrapper
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
dummy_factory
=
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
else
:
logger
.
warning
(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead."
,
model_cls
)
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
return
dummy_factory
def
dummy_data_for_profiling
(
self
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
mm_registry
:
"MultiModalRegistry"
,
is_encoder_data
:
bool
=
False
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
"""
Create dummy data for profiling the memory usage of a model.
...
...
@@ -180,22 +219,29 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
dummy_factory
=
self
.
_dummy_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
if
is_encoder_data
:
dummy_factory
=
self
.
_get_dummy_encoder_data_factory
(
model_cls
)
else
:
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
mm_counts
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
**
mm_processor_kwargs
)
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
seq_data
.
prompt_token_ids
assert
len
(
num_tokens
)
>=
seq_len
,
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
len
(
num_tokens
)
<
seq_len
:
if
is_encoder_data
:
logger
.
warning
(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead."
,
seq_len
,
len
(
num_tokens
))
else
:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
mm_data
is
not
None
:
for
k
,
v
in
mm_data
.
items
():
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
...
...
@@ -235,6 +281,10 @@ class InputRegistry:
return
wrapper
def _get_model_input_processor(self, model_cls: Type[nn.Module]):
    """
    Return the input processor to use for a model class.

    The processor registered for ``model_cls`` in
    ``self._input_processors_by_model_type`` is returned when there is one;
    otherwise ``self._default_input_processor`` is used as the fallback.
    """
    return self._input_processors_by_model_type.get(
        model_cls, self._default_input_processor)
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
inputs
:
LLMInputs
)
->
LLMInputs
:
"""
...
...
@@ -249,15 +299,17 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
processor
=
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input
_processor
)
mm_
processor
_kwargs
=
get_allowed_kwarg_only_overrides
(
processor
,
overrides
=
model_config
.
mm
_processor
_kwargs
)
return
processor
(
InputContext
(
model_config
),
inputs
)
return
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwargs
)
def create_input_processor(self, model_config: "ModelConfig"):
    """
    Create an input processor (see :meth:`process_input`) for a
    specific model.

    The result is a :func:`functools.partial` of ``self.process_input``
    with ``model_config`` already bound as its first argument, so callers
    only need to supply the remaining arguments.
    """
    return functools.partial(self.process_input, model_config)
vllm/lora/ops/bgmv_expand.py
View file @
539aa992
...
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
539aa992
...
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offs
e
t (int): output_tensor's offs
e
t
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
...
...
vllm/lora/ops/sgmv_expand.py
View file @
539aa992
...
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
)
->
None
:
"""
...
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
...
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
539aa992
...
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
...
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int):
The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional):
Defaults to False
.
adds the final lora
results to the output.
.
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
539aa992
...
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
)
->
None
:
"""
...
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
length of the sequences
in the batch
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
...
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment