Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
539aa992
Commit
539aa992
authored
Sep 27, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.2' into v0.6.2-dev
parents
93872128
7193774b
Changes
383
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
287 additions
and
160 deletions
+287
-160
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+47
-20
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+8
-9
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+33
-18
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+9
-8
vllm/envs.py
vllm/envs.py
+14
-3
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+1
-0
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+0
-14
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+18
-28
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+4
-2
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+8
-2
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+6
-1
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+11
-5
vllm/inputs/data.py
vllm/inputs/data.py
+6
-0
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+15
-7
vllm/inputs/registry.py
vllm/inputs/registry.py
+74
-22
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+1
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+1
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+10
-6
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+11
-7
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+10
-6
No files found.
vllm/entrypoints/openai/serving_completion.py
View file @
539aa992
...
@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
...
@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
from
fastapi
import
Request
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
...
@@ -18,9 +18,12 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
CompletionStreamResponse
,
ErrorResponse
,
UsageInfo
)
ErrorResponse
,
RequestResponseMetadata
,
UsageInfo
)
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
,
OpenAIServing
,
PromptAdapterPath
)
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -43,18 +46,18 @@ class OpenAIServingCompletion(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
return_tokens_as_token_ids
:
bool
=
False
,
return_tokens_as_token_ids
:
bool
=
False
,
):
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
prompt_adapters
,
prompt_adapters
=
prompt_adapters
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
...
@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -78,15 +81,25 @@ class OpenAIServingCompletion(OpenAIServing):
if
error_check_ret
is
not
None
:
if
error_check_ret
is
not
None
:
return
error_check_ret
return
error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if
self
.
engine_client
.
errored
:
raise
self
.
engine_client
.
dead_error
# Return error for unsupported features.
# Return error for unsupported features.
if
request
.
suffix
is
not
None
:
if
request
.
suffix
is
not
None
:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"suffix is not currently supported"
)
"suffix is not currently supported"
)
model_name
=
self
.
se
rved
_model_name
s
[
0
]
model_name
=
self
.
ba
se_model_
paths
[
0
].
name
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
request_id
=
f
"cmpl-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
time
())
created_time
=
int
(
time
.
time
())
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
# Schedule the request and get the result generator.
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
generators
:
List
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
try
:
try
:
...
@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -95,8 +108,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
lora_request
)
guided_decode_logits_processor
=
(
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
...
@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -124,8 +136,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
)
is_tracing_enabled
=
(
is_tracing_enabled
=
(
await
await
self
.
async_
engine_client
.
is_tracing_enabled
())
self
.
engine_client
.
is_tracing_enabled
())
trace_headers
=
None
trace_headers
=
None
if
is_tracing_enabled
:
if
is_tracing_enabled
:
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
trace_headers
=
extract_trace_headers
(
raw_request
.
headers
)
...
@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -133,7 +145,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
raw_request
.
headers
):
log_tracing_disabled_warning
()
log_tracing_disabled_warning
()
generator
=
self
.
async_
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
sampling_params
,
request_id_item
,
request_id_item
,
...
@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -159,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
# Streaming response
if
stream
:
if
stream
:
return
self
.
completion_stream_generator
(
request
,
return
self
.
completion_stream_generator
(
result_generator
,
request
,
request_id
,
result_generator
,
created_time
,
request_id
,
model_name
,
created_time
,
num_prompts
=
len
(
prompts
),
model_name
,
tokenizer
=
tokenizer
)
num_prompts
=
len
(
prompts
),
tokenizer
=
tokenizer
,
request_metadata
=
request_metadata
)
# Non-streaming response
# Non-streaming response
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
final_res_batch
:
List
[
Optional
[
RequestOutput
]]
=
[
None
]
*
len
(
prompts
)
...
@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -192,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
,
created_time
,
model_name
,
model_name
,
tokenizer
,
tokenizer
,
request_metadata
,
)
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
return
self
.
create_error_response
(
"Client disconnected"
)
...
@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -221,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
model_name
:
str
,
model_name
:
str
,
num_prompts
:
int
,
num_prompts
:
int
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
AsyncGenerator
[
str
,
None
]:
)
->
AsyncGenerator
[
str
,
None
]:
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_text_lens
=
[
0
]
*
num_choices
*
num_prompts
previous_text_lens
=
[
0
]
*
num_choices
*
num_prompts
...
@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -340,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
exclude_unset
=
False
,
exclude_none
=
True
))
exclude_unset
=
False
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens
=
sum
(
num_prompt_tokens
)
total_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
data
=
self
.
create_streaming_error_response
(
str
(
e
))
data
=
self
.
create_streaming_error_response
(
str
(
e
))
...
@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -354,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time
:
int
,
created_time
:
int
,
model_name
:
str
,
model_name
:
str
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request_metadata
:
RequestResponseMetadata
,
)
->
CompletionResponse
:
)
->
CompletionResponse
:
choices
:
List
[
CompletionResponseChoice
]
=
[]
choices
:
List
[
CompletionResponseChoice
]
=
[]
num_prompt_tokens
=
0
num_prompt_tokens
=
0
...
@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -427,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
total_tokens
=
num_prompt_tokens
+
num_generated_tokens
,
)
)
request_metadata
.
final_usage_info
=
usage
return
CompletionResponse
(
return
CompletionResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
539aa992
...
@@ -8,13 +8,13 @@ from fastapi import Request
...
@@ -8,13 +8,13 @@ from fastapi import Request
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
from
vllm.entrypoints.openai.protocol
import
(
EmbeddingRequest
,
EmbeddingResponse
,
EmbeddingResponse
,
EmbeddingResponseData
,
EmbeddingResponseData
,
ErrorResponse
,
UsageInfo
)
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
,
OpenAIServing
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
EmbeddingOutput
,
EmbeddingRequestOutput
from
vllm.outputs
import
EmbeddingOutput
,
EmbeddingRequestOutput
from
vllm.utils
import
merge_async_iterators
,
random_uuid
from
vllm.utils
import
merge_async_iterators
,
random_uuid
...
@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -71,15 +71,15 @@ class OpenAIServingEmbedding(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
*
,
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
):
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
None
,
lora_modules
=
None
,
prompt_adapters
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
request_logger
=
request_logger
)
...
@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
lora_request
)
pooling_params
=
request
.
to_pooling_params
()
pooling_params
=
request
.
to_pooling_params
()
...
@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"Prompt adapter is not supported "
"for embedding models"
)
"for embedding models"
)
generator
=
self
.
async_
engine_client
.
encode
(
generator
=
self
.
engine_client
.
encode
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
pooling_params
,
pooling_params
,
request_id_item
,
request_id_item
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
539aa992
...
@@ -8,7 +8,7 @@ from pydantic import Field
...
@@ -8,7 +8,7 @@ from pydantic import Field
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
...
@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclass
class
BaseModelPath
:
name
:
str
model_path
:
str
@
dataclass
@
dataclass
class
PromptAdapterPath
:
class
PromptAdapterPath
:
name
:
str
name
:
str
...
@@ -49,6 +55,7 @@ class PromptAdapterPath:
...
@@ -49,6 +55,7 @@ class PromptAdapterPath:
class
LoRAModulePath
:
class
LoRAModulePath
:
name
:
str
name
:
str
path
:
str
path
:
str
base_model_name
:
Optional
[
str
]
=
None
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
AnyRequest
=
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
...
@@ -64,9 +71,9 @@ class OpenAIServing:
...
@@ -64,9 +71,9 @@ class OpenAIServing:
def
__init__
(
def
__init__
(
self
,
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
prompt_adapters
:
Optional
[
List
[
PromptAdapterPath
]],
...
@@ -75,21 +82,24 @@ class OpenAIServing:
...
@@ -75,21 +82,24 @@ class OpenAIServing:
):
):
super
().
__init__
()
super
().
__init__
()
self
.
async_
engine_client
=
async_
engine_client
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
self
.
se
rved
_model_
names
=
served
_model_
name
s
self
.
ba
se_model_
paths
=
base
_model_
path
s
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_requests
=
[]
self
.
lora_requests
=
[]
if
lora_modules
is
not
None
:
if
lora_modules
is
not
None
:
self
.
lora_requests
=
[
self
.
lora_requests
=
[
LoRARequest
(
LoRARequest
(
lora_name
=
lora
.
name
,
lora_name
=
lora
.
name
,
lora_int_id
=
i
,
lora_int_id
=
i
,
lora_path
=
lora
.
path
,
lora_path
=
lora
.
path
,
base_model_name
=
lora
.
base_model_name
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
if
lora
.
base_model_name
and
self
.
_is_model_supported
(
lora
.
base_model_name
)
else
self
.
base_model_paths
[
0
].
name
)
for
i
,
lora
in
enumerate
(
lora_modules
,
start
=
1
)
]
]
self
.
prompt_adapter_requests
=
[]
self
.
prompt_adapter_requests
=
[]
...
@@ -112,21 +122,23 @@ class OpenAIServing:
...
@@ -112,21 +122,23 @@ class OpenAIServing:
async
def
show_available_models
(
self
)
->
ModelList
:
async
def
show_available_models
(
self
)
->
ModelList
:
"""Show available models. Right now we only have one model."""
"""Show available models. Right now we only have one model."""
model_cards
=
[
model_cards
=
[
ModelCard
(
id
=
se
rved
_model
_
name
,
ModelCard
(
id
=
ba
se_model
.
name
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
root
=
self
.
served_model_names
[
0
]
,
root
=
base_model
.
model_path
,
permission
=
[
ModelPermission
()])
permission
=
[
ModelPermission
()])
for
se
rved
_model
_name
in
self
.
se
rved
_model_
name
s
for
ba
se_model
in
self
.
ba
se_model_
path
s
]
]
lora_cards
=
[
lora_cards
=
[
ModelCard
(
id
=
lora
.
lora_name
,
ModelCard
(
id
=
lora
.
lora_name
,
root
=
self
.
served_model_names
[
0
],
root
=
lora
.
local_path
,
parent
=
lora
.
base_model_name
if
lora
.
base_model_name
else
self
.
base_model_paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
permission
=
[
ModelPermission
()])
for
lora
in
self
.
lora_requests
for
lora
in
self
.
lora_requests
]
]
prompt_adapter_cards
=
[
prompt_adapter_cards
=
[
ModelCard
(
id
=
prompt_adapter
.
prompt_adapter_name
,
ModelCard
(
id
=
prompt_adapter
.
prompt_adapter_name
,
root
=
self
.
se
rved
_model_name
s
[
0
]
,
root
=
self
.
ba
se_model_
paths
[
0
].
name
,
permission
=
[
ModelPermission
()])
permission
=
[
ModelPermission
()])
for
prompt_adapter
in
self
.
prompt_adapter_requests
for
prompt_adapter
in
self
.
prompt_adapter_requests
]
]
...
@@ -159,7 +171,7 @@ class OpenAIServing:
...
@@ -159,7 +171,7 @@ class OpenAIServing:
async
def
_guided_decode_logits_processor
(
async
def
_guided_decode_logits_processor
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
],
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
tokenizer
:
AnyTokenizer
)
->
Optional
[
LogitsProcessor
]:
decoding_config
=
await
self
.
async_
engine_client
.
get_decoding_config
()
decoding_config
=
await
self
.
engine_client
.
get_decoding_config
()
guided_decoding_backend
=
request
.
guided_decoding_backend
\
guided_decoding_backend
=
request
.
guided_decoding_backend
\
or
decoding_config
.
guided_decoding_backend
or
decoding_config
.
guided_decoding_backend
return
await
get_guided_decoding_logits_processor
(
return
await
get_guided_decoding_logits_processor
(
...
@@ -169,7 +181,7 @@ class OpenAIServing:
...
@@ -169,7 +181,7 @@ class OpenAIServing:
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
)
->
Optional
[
ErrorResponse
]:
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
if
self
.
_is_model_supported
(
request
.
model
)
:
return
None
return
None
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
lora_requests
]:
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
lora_requests
]:
return
None
return
None
...
@@ -187,7 +199,7 @@ class OpenAIServing:
...
@@ -187,7 +199,7 @@ class OpenAIServing:
self
,
request
:
AnyRequest
self
,
request
:
AnyRequest
)
->
Union
[
Tuple
[
None
,
None
],
Tuple
[
LoRARequest
,
None
],
Tuple
[
)
->
Union
[
Tuple
[
None
,
None
],
Tuple
[
LoRARequest
,
None
],
Tuple
[
None
,
PromptAdapterRequest
]]:
None
,
PromptAdapterRequest
]]:
if
request
.
model
in
self
.
served_model_names
:
if
self
.
_is_model_supported
(
request
.
model
)
:
return
None
,
None
return
None
,
None
for
lora
in
self
.
lora_requests
:
for
lora
in
self
.
lora_requests
:
if
request
.
model
==
lora
.
lora_name
:
if
request
.
model
==
lora
.
lora_name
:
...
@@ -480,3 +492,6 @@ class OpenAIServing:
...
@@ -480,3 +492,6 @@ class OpenAIServing:
if
lora_request
.
lora_name
!=
lora_name
if
lora_request
.
lora_name
!=
lora_name
]
]
return
f
"Success: LoRA adapter '
{
lora_name
}
' removed successfully."
return
f
"Success: LoRA adapter '
{
lora_name
}
' removed successfully."
def
_is_model_supported
(
self
,
model_name
):
return
any
(
model
.
name
==
model_name
for
model
in
self
.
base_model_paths
)
vllm/entrypoints/openai/serving_tokenization.py
View file @
539aa992
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
Async
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
load_chat_template
,
load_chat_template
,
...
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
...
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest
,
TokenizeRequest
,
TokenizeResponse
)
TokenizeResponse
)
# yapf: enable
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
from
vllm.entrypoints.openai.serving_engine
import
(
BaseModelPath
,
LoRAModulePath
,
OpenAIServing
)
OpenAIServing
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
MistralTokenizer
...
@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -29,17 +30,17 @@ class OpenAIServingTokenization(OpenAIServing):
def
__init__
(
def
__init__
(
self
,
self
,
async_
engine_client
:
Async
EngineClient
,
engine_client
:
EngineClient
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
se
rved
_model_
name
s
:
List
[
str
],
ba
se_model_
path
s
:
List
[
BaseModelPath
],
*
,
*
,
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
lora_modules
:
Optional
[
List
[
LoRAModulePath
]],
request_logger
:
Optional
[
RequestLogger
],
request_logger
:
Optional
[
RequestLogger
],
chat_template
:
Optional
[
str
],
chat_template
:
Optional
[
str
],
):
):
super
().
__init__
(
async_
engine_client
=
async_
engine_client
,
super
().
__init__
(
engine_client
=
engine_client
,
model_config
=
model_config
,
model_config
=
model_config
,
se
rved
_model_
names
=
served
_model_
name
s
,
ba
se_model_
paths
=
base
_model_
path
s
,
lora_modules
=
lora_modules
,
lora_modules
=
lora_modules
,
prompt_adapters
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
request_logger
=
request_logger
)
...
@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -66,7 +67,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
prompt
:
Union
[
str
,
List
[
int
]]
prompt
:
Union
[
str
,
List
[
int
]]
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
...
@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -132,7 +133,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request
,
prompt_adapter_request
,
)
=
self
.
_maybe_get_adapters
(
request
)
)
=
self
.
_maybe_get_adapters
(
request
)
tokenizer
=
await
self
.
async_
engine_client
.
get_tokenizer
(
lora_request
)
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
self
.
_log_inputs
(
request_id
,
self
.
_log_inputs
(
request_id
,
request
.
tokens
,
request
.
tokens
,
...
...
vllm/envs.py
View file @
539aa992
...
@@ -59,10 +59,12 @@ if TYPE_CHECKING:
...
@@ -59,10 +59,12 @@ if TYPE_CHECKING:
VERBOSE
:
bool
=
False
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
:
int
=
5
000
VLLM_RPC_TIMEOUT
:
int
=
10
000
# ms
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -206,6 +208,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_PA_PRINT_PARAM"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH"
,
"0"
)
==
"1"
,
# Internal flag to enable Dynamo graph capture
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
...
@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -214,6 +220,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
,
"0"
)),
# Internal flag to enable Dynamo fullgraph capture
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
lambda
:
bool
(
...
@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -399,8 +410,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Time in ms for the zmq client to wait for a response from the backend
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
# server for simple data operations
"VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
"
:
"VLLM_RPC_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_
GET_DATA_
TIMEOUT
_MS
"
,
"
5
000"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_RPC_TIMEOUT"
,
"
10
000"
)),
# a list of plugin names to load, separated by commas.
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is not set, it means all plugins will be loaded
...
...
vllm/executor/cpu_executor.py
View file @
539aa992
...
@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
...
@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
))
for
rank
in
range
(
1
,
world_size
)
))
for
rank
in
range
(
1
,
world_size
)
]
]
self
.
worker_monitor
=
None
if
world_size
!=
1
or
is_async
:
if
world_size
!=
1
or
is_async
:
if
is_async
:
if
is_async
:
async_worker_list
=
self
.
workers
+
[
self
.
driver_worker
]
async_worker_list
=
self
.
workers
+
[
self
.
driver_worker
]
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
539aa992
import
asyncio
import
asyncio
import
os
import
os
import
signal
import
threading
import
weakref
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
...
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref
=
weakref
.
ref
(
self
)
def
shutdown
(
signum
,
frame
):
if
executor
:
=
ref
():
executor
.
shutdown
()
if
threading
.
current_thread
()
is
threading
.
main_thread
():
signal
.
signal
(
signal
.
SIGINT
,
shutdown
)
signal
.
signal
(
signal
.
SIGTERM
,
shutdown
)
self
.
driver_worker
=
self
.
_create_worker
(
self
.
driver_worker
=
self
.
_create_worker
(
distributed_init_method
=
distributed_init_method
)
distributed_init_method
=
distributed_init_method
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"init_device"
)
...
...
vllm/executor/multiproc_worker_utils.py
View file @
539aa992
...
@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread):
...
@@ -76,8 +76,7 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
"""Handle results from all workers (in background thread)"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
daemon
=
False
)
super
().
__init__
(
daemon
=
True
)
# super().__init__(daemon=True)
self
.
result_queue
=
mp
.
Queue
()
self
.
result_queue
=
mp
.
Queue
()
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
self
.
tasks
:
Dict
[
uuid
.
UUID
,
Union
[
ResultFuture
,
asyncio
.
Future
]]
=
{}
...
@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread):
...
@@ -101,8 +100,7 @@ class WorkerMonitor(threading.Thread):
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
def
__init__
(
self
,
workers
:
List
[
'ProcessWorkerWrapper'
],
result_handler
:
ResultHandler
):
result_handler
:
ResultHandler
):
super
().
__init__
(
daemon
=
False
)
super
().
__init__
(
daemon
=
True
)
# super().__init__(daemon=True)
self
.
workers
=
workers
self
.
workers
=
workers
self
.
result_handler
=
result_handler
self
.
result_handler
=
result_handler
self
.
_close
=
False
self
.
_close
=
False
...
@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread):
...
@@ -114,30 +112,16 @@ class WorkerMonitor(threading.Thread):
self
.
_close
=
True
self
.
_close
=
True
# Kill / cleanup all workers
# Kill / cleanup all workers
# for worker in self.workers:
for
worker
in
self
.
workers
:
# process = worker.process
process
=
worker
.
process
# if process.sentinel in dead_sentinels:
if
process
.
sentinel
in
dead_sentinels
:
# process.join(JOIN_TIMEOUT_S)
process
.
join
(
JOIN_TIMEOUT_S
)
# if process.exitcode is not None and process.exitcode != 0:
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
# logger.error("Worker %s pid %s died, exit code: %s",
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
# process.name, process.pid, process.exitcode)
process
.
name
,
process
.
pid
,
process
.
exitcode
)
if
not
sys
.
is_finalizing
():
# Kill / cleanup all workers
died_count
=
0
for
worker
in
self
.
workers
:
process
=
worker
.
process
if
process
.
sentinel
in
dead_sentinels
:
process
.
join
(
JOIN_TIMEOUT_S
)
if
process
.
exitcode
is
not
None
and
process
.
exitcode
!=
0
:
died_count
+=
1
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
if
died_count
<
len
(
self
.
workers
):
logger
.
info
(
"Killing remaining local vLLM worker processes"
)
# Cleanup any remaining workers
# Cleanup any remaining workers
# logger.info("Killing local vLLM worker processes")
if
logger
:
logger
.
info
(
"Killing local vLLM worker processes"
)
for
worker
in
self
.
workers
:
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
# Must be done after worker task queues are all closed
...
@@ -184,6 +168,8 @@ class ProcessWorkerWrapper:
...
@@ -184,6 +168,8 @@ class ProcessWorkerWrapper:
self
.
tasks
[
task_id
]
=
future
self
.
tasks
[
task_id
]
=
future
try
:
try
:
self
.
_task_queue
.
put
((
task_id
,
method
,
args
,
kwargs
))
self
.
_task_queue
.
put
((
task_id
,
method
,
args
,
kwargs
))
except
SystemExit
:
raise
except
BaseException
as
e
:
except
BaseException
as
e
:
del
self
.
tasks
[
task_id
]
del
self
.
tasks
[
task_id
]
raise
ChildProcessError
(
"worker died"
)
from
e
raise
ChildProcessError
(
"worker died"
)
from
e
...
@@ -238,6 +224,10 @@ def _run_worker_process(
...
@@ -238,6 +224,10 @@ def _run_worker_process(
try
:
try
:
executor
=
getattr
(
worker
,
method
)
executor
=
getattr
(
worker
,
method
)
output
=
executor
(
*
args
,
**
kwargs
)
output
=
executor
(
*
args
,
**
kwargs
)
except
SystemExit
:
raise
except
KeyboardInterrupt
:
break
except
BaseException
as
e
:
except
BaseException
as
e
:
tb
=
traceback
.
format_exc
()
tb
=
traceback
.
format_exc
()
logger
.
error
(
logger
.
error
(
...
@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
...
@@ -278,4 +268,4 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
False
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
start_new_line
=
True
# type: ignore[attr-defined]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
\ No newline at end of file
vllm/executor/ray_gpu_executor.py
View file @
539aa992
...
@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -437,8 +437,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
required_version
=
version
.
parse
(
"2.35"
)
required_version
=
version
.
parse
(
"2.35"
)
current_version
=
version
.
parse
(
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
# TODO: update the constraint once we adapt to the backward
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
# incompatible API change from ray 2.36
if
current_version
!=
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
import
importlib.util
import
importlib.util
...
...
vllm/executor/ray_tpu_executor.py
View file @
539aa992
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class
RayTPUExecutor
(
TPUExecutor
):
class
RayTPUExecutor
(
TPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# This is non-None when the execute model loop is running
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...
@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor):
...
@@ -68,8 +70,12 @@ class RayTPUExecutor(TPUExecutor):
)
)
assert
self
.
speculative_config
is
None
assert
self
.
speculative_config
is
None
worker_module_name
=
"vllm.worker.tpu_worker"
if
self
.
scheduler_config
.
is_multi_step
:
worker_class_name
=
"TPUWorker"
worker_module_name
=
"vllm.worker.multi_step_tpu_worker"
worker_class_name
=
"MultiStepTPUWorker"
else
:
worker_module_name
=
"vllm.worker.tpu_worker"
worker_class_name
=
"TPUWorker"
# GKE does not fetch environment information from metadata server
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# and instead sets these from within the Ray process. Therefore we
...
...
vllm/executor/ray_utils.py
View file @
539aa992
...
@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800
...
@@ -18,9 +18,14 @@ PG_WAIT_TIMEOUT = 1800
try
:
try
:
import
ray
import
ray
from
ray._private.state
import
available_resources_per_node
from
ray.util
import
placement_group_table
from
ray.util
import
placement_group_table
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
try
:
from
ray._private.state
import
available_resources_per_node
except
ImportError
:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from
ray._private.state
import
state
as
_state
available_resources_per_node
=
_state
.
_available_resources_per_node
class
RayWorkerWrapper
(
WorkerWrapperBase
):
class
RayWorkerWrapper
(
WorkerWrapperBase
):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
...
...
vllm/executor/tpu_executor.py
View file @
539aa992
...
@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
...
@@ -62,11 +62,17 @@ class TPUExecutor(ExecutorBase):
rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
,
distributed_init_method
:
Optional
[
str
]
=
None
,
):
):
from
vllm.worker.tpu_worker
import
TPUWorker
if
self
.
scheduler_config
.
is_multi_step
:
from
vllm.worker.multi_step_tpu_worker
import
MultiStepTPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
worker
=
MultiStepTPUWorker
(
**
self
.
_get_worker_kwargs
(
distributed_init_method
))
local_rank
,
rank
,
distributed_init_method
))
return
worker
return
worker
else
:
from
vllm.worker.tpu_worker
import
TPUWorker
worker
=
TPUWorker
(
**
self
.
_get_worker_kwargs
(
local_rank
,
rank
,
distributed_init_method
))
return
worker
def
initialize_cache
(
def
initialize_cache
(
self
,
self
,
...
...
vllm/inputs/data.py
View file @
539aa992
...
@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
...
@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available.
available.
"""
"""
encoder_multi_modal_data
:
NotRequired
[
Optional
[
"MultiModalDataDict"
]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1
=
TypeVar
(
"_T1"
,
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
bound
=
SingletonPromptInputs
,
...
...
vllm/inputs/preprocess.py
View file @
539aa992
...
@@ -128,6 +128,7 @@ class InputPreprocessor:
...
@@ -128,6 +128,7 @@ class InputPreprocessor:
def
_prepare_decoder_input_ids_for_generation
(
def
_prepare_decoder_input_ids_for_generation
(
self
,
self
,
decoder_input_ids
:
Optional
[
List
[
int
]],
decoder_input_ids
:
Optional
[
List
[
int
]],
force_bos
:
bool
=
True
,
)
->
List
[
int
]:
)
->
List
[
int
]:
"""
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Prepares `decoder_input_ids` for generation with encoder-decoder models.
...
@@ -157,8 +158,8 @@ class InputPreprocessor:
...
@@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
if
(
len
(
decoder_input_ids
)
==
0
if
force_bos
and
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
decoder_input_ids
=
[
decoder_start_token_id
]
+
decoder_input_ids
decoder_input_ids
=
[
decoder_start_token_id
]
+
decoder_input_ids
return
decoder_input_ids
return
decoder_input_ids
...
@@ -295,18 +296,25 @@ class InputPreprocessor:
...
@@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
if
encoder_mm_data
is
not
None
or
decoder_mm_data
is
not
None
:
if
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modal encoder-decoder models are "
raise
ValueError
(
"not supported yet"
)
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet"
)
decoder_prompt_ids
=
(
# For Multi-Modal models (e.g., mllama), the text input can be
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
))
# <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
,
force_bos
=
(
encoder_mm_data
is
None
and
decoder_mm_data
is
None
)))
return
EncoderDecoderLLMInputs
(
return
EncoderDecoderLLMInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
prompt
=
decoder_prompt
,
multi_modal_data
=
decoder_mm_data
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
encoder_prompt
=
encoder_prompt
,
encoder_multi_modal_data
=
encoder_mm_data
,
)
)
def
_process_encoder_decoder_prompt
(
def
_process_encoder_decoder_prompt
(
...
...
vllm/inputs/registry.py
View file @
539aa992
import
functools
import
functools
from
array
import
array
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
...
@@ -10,6 +9,7 @@ from transformers import PretrainedConfig
...
@@ -10,6 +9,7 @@ from transformers import PretrainedConfig
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_allowed_kwarg_only_overrides
from
.data
import
LLMInputs
from
.data
import
LLMInputs
...
@@ -22,10 +22,6 @@ logger = init_logger(__name__)
...
@@ -22,10 +22,6 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
InputContext
:
class
InputContext
:
...
@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol):
...
@@ -73,12 +69,17 @@ class DummyDataFactory(Protocol):
ctx
:
InputContext
,
ctx
:
InputContext
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
**
mm_processor_kwargs
:
Any
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
"""
"""
Create dummy data to be inputted into the model.
Create dummy data to be inputted into the model.
Note:
Note:
:data:`InputProcessor` is not applied to the dummy data.
:data:`InputProcessor` is not applied to the dummy data.
The :code:`mm_processor_kwargs` are overrides provided at
initialization time to values in the config whose values
may affect the number of tokens per instance.
"""
"""
...
...
...
@@ -111,6 +112,8 @@ class InputRegistry:
...
@@ -111,6 +112,8 @@ class InputRegistry:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
_dummy_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
self
.
_dummy_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
DummyDataFactory
]
=
{}
DummyDataFactory
]
=
{}
self
.
_dummy_encoder_factories_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
DummyDataFactory
]
=
{}
self
.
_input_processors_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
self
.
_input_processors_by_model_type
:
Dict
[
Type
[
nn
.
Module
],
InputProcessor
]
=
{}
InputProcessor
]
=
{}
...
@@ -130,8 +133,7 @@ class InputRegistry:
...
@@ -130,8 +133,7 @@ class InputRegistry:
# Avoid circular import
# Avoid circular import
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
(
dummy_seq_data
=
SequenceData
.
from_token_counts
((
0
,
seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_multi_modal_data
=
None
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
return
dummy_seq_data
,
dummy_multi_modal_data
...
@@ -158,11 +160,48 @@ class InputRegistry:
...
@@ -158,11 +160,48 @@ class InputRegistry:
return
wrapper
return
wrapper
def
_get_dummy_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
return
self
.
_dummy_factories_by_model_type
\
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
def
register_dummy_encoder_data
(
self
,
factory
:
DummyDataFactory
):
"""
Register a dummy encoder data factory to a model class
This is similar to :meth:`~register_dummy_data`, but for encoder input.
"""
def
wrapper
(
model_cls
:
N
)
->
N
:
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
logger
.
warning
(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one."
,
model_cls
,
self
)
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
=
factory
return
model_cls
return
wrapper
def
_get_dummy_encoder_data_factory
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
if
model_cls
in
self
.
_dummy_encoder_factories_by_model_type
:
dummy_factory
=
self
.
_dummy_encoder_factories_by_model_type
[
model_cls
]
else
:
logger
.
warning
(
"No dummy encoder data factory registered to %s. "
"Using the dummy data factory for the model instead."
,
model_cls
)
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
return
dummy_factory
def
dummy_data_for_profiling
(
def
dummy_data_for_profiling
(
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
seq_len
:
int
,
mm_registry
:
"MultiModalRegistry"
,
mm_registry
:
"MultiModalRegistry"
,
is_encoder_data
:
bool
=
False
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]:
"""
"""
Create dummy data for profiling the memory usage of a model.
Create dummy data for profiling the memory usage of a model.
...
@@ -180,22 +219,29 @@ class InputRegistry:
...
@@ -180,22 +219,29 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
dummy_factory
=
self
.
_dummy_factories_by_model_type
\
if
is_encoder_data
:
.
get
(
model_cls
,
self
.
_default_dummy_data_factory
)
dummy_factory
=
self
.
_get_dummy_encoder_data_factory
(
model_cls
)
else
:
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
mm_counts
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
mm_counts
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
seq_data
,
mm_data
=
dummy_factory
(
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
InputContext
(
model_config
),
_MultiModalCounts
(
mm_counts
),
seq_len
,
**
mm_processor_kwargs
)
_MultiModalCounts
(
mm_counts
),
)
# Having more tokens is over-conservative but otherwise fine
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
seq_data
.
prompt_token_ids
num_tokens
=
seq_data
.
prompt_token_ids
assert
len
(
num_tokens
)
>=
seq_len
,
(
if
len
(
num_tokens
)
<
seq_len
:
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
if
is_encoder_data
:
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
logger
.
warning
(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead."
,
seq_len
,
len
(
num_tokens
))
else
:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
for
k
,
v
in
mm_data
.
items
():
for
k
,
v
in
mm_data
.
items
():
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
...
@@ -235,6 +281,10 @@ class InputRegistry:
...
@@ -235,6 +281,10 @@ class InputRegistry:
return
wrapper
return
wrapper
def
_get_model_input_processor
(
self
,
model_cls
:
Type
[
nn
.
Module
]):
return
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input_processor
)
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
inputs
:
LLMInputs
)
->
LLMInputs
:
inputs
:
LLMInputs
)
->
LLMInputs
:
"""
"""
...
@@ -249,15 +299,17 @@ class InputRegistry:
...
@@ -249,15 +299,17 @@ class InputRegistry:
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.model_executor.model_loader
import
get_model_architecture
model_cls
,
_
=
get_model_architecture
(
model_config
)
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
processor
=
self
.
_input_processors_by_model_type
\
mm_
processor
_kwargs
=
get_allowed_kwarg_only_overrides
(
.
get
(
model_cls
,
self
.
_default_input
_processor
)
processor
,
overrides
=
model_config
.
mm
_processor
_kwargs
)
return
processor
(
InputContext
(
model_config
),
inputs
)
return
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwargs
)
def
create_input_processor
(
self
,
model_config
:
"ModelConfig"
):
def
create_input_processor
(
self
,
model_config
:
"ModelConfig"
):
"""
"""
Create an input processor (see :meth:`process_input`) for a
Create an input processor (see :meth:`
_
process_input`) for a
specific model.
specific model.
"""
"""
return
functools
.
partial
(
self
.
process_input
,
model_config
)
return
functools
.
partial
(
self
.
process_input
,
model_config
)
vllm/lora/ops/bgmv_expand.py
View file @
539aa992
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
...
@@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
corresponding to each batch, An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
add_inputs (bool, optional): Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
results to the output.
"""
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
539aa992
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
...
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
corresponding to each batch, An index of -1 means no lora should be
applied.
applied.
slice_offst (int): output_tensor's offst
slice_offs
e
t (int): output_tensor's offs
e
t
slice_size (int): current output_tensor's size
slice_size (int): current output_tensor's size
batches (int): batch size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
add_inputs (bool, optional): Defaults to False.
...
...
vllm/lora/ops/sgmv_expand.py
View file @
539aa992
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
...
@@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
add_inputs
:
bool
=
False
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
...
@@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences in the
in the batch
batch.
add_inputs (bool, optional): Defaults to False. adds the final lora
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
results to the output.
"""
"""
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
...
@@ -134,6 +137,7 @@ def _sgmv_expand(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
539aa992
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
...
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
add_inputs
:
bool
=
False
,
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
...
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4, 10].
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int):
The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
slice_size (int): current output_tensor's size
add_inputs (bool, optional):
Defaults to False
.
adds the final lora
add_inputs (bool, optional): Defaults to False
,
adds the final lora
results to the output.
.
results to the output.
"""
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
...
@@ -145,6 +148,7 @@ def _sgmv_expand_slice(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
539aa992
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
...
@@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
batches
:
int
,
max_seq_length
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
scaling
:
float
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
...
@@ -120,17 +121,19 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g.,
if the sequence length is [4, 6], it is
[0, 4].
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,).
r
ecord the sequence
seq_len_tensor (torch.Tensor): (batch_size,).
R
ecord the sequence
length of the sequences
in the batch
length of the sequences in the batch
.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
corresponding to each batch. An index of -1 means no lora should be
applied.
applied.
batches (int): batch size
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences in the
in the batch
batch.
scaling (float): Scaling factor.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
...
@@ -138,6 +141,7 @@ def _sgmv_shrink(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
bfloat16
,
]
]
assert
inputs
.
size
(
0
)
==
token_nums
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment