Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99b471c2
Commit
99b471c2
authored
May 21, 2024
by
zhuwenwen
Browse files
merge v0.4.1
parents
1925d2e9
468d761b
Changes
336
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
1817 additions
and
437 deletions
+1817
-437
vllm/transformers_utils/configs/jais.py
vllm/transformers_utils/configs/jais.py
+4
-2
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+165
-9
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+23
-158
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+1
-1
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+3
-0
vllm/transformers_utils/tokenizers/baichuan.py
vllm/transformers_utils/tokenizers/baichuan.py
+7
-7
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+4
-4
vllm/utils.py
vllm/utils.py
+195
-26
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+4
-5
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+432
-0
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+320
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+327
-131
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+15
-9
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+53
-7
vllm/worker/worker.py
vllm/worker/worker.py
+106
-78
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+158
-0
No files found.
vllm/transformers_utils/configs/jais.py
View file @
99b471c2
...
@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
...
@@ -222,13 +222,15 @@ class JAISConfig(PretrainedConfig):
f
"got
{
alibi_scaling_type
}
"
)
f
"got
{
alibi_scaling_type
}
"
)
if
(
alibi_scaling_factor
is
not
None
if
(
alibi_scaling_factor
is
not
None
and
not
isinstance
(
alibi_scaling_factor
,
float
)
and
not
isinstance
(
alibi_scaling_factor
,
float
)
or
alibi_scaling_factor
<=
1.0
):
or
(
alibi_scaling_factor
is
not
None
and
alibi_scaling_factor
<=
1.0
)):
raise
ValueError
(
raise
ValueError
(
f
"`alibi_scaling`'s factor field must be a float > 1.0,"
f
"`alibi_scaling`'s factor field must be a float > 1.0,"
f
"got
{
alibi_scaling_factor
}
"
)
f
"got
{
alibi_scaling_factor
}
"
)
if
(
alibi_dynamic_scaling
is
not
None
if
(
alibi_dynamic_scaling
is
not
None
and
not
isinstance
(
alibi_dynamic_scaling
,
int
)
and
not
isinstance
(
alibi_dynamic_scaling
,
int
)
or
alibi_dynamic_scaling
<=
1
):
or
(
alibi_dynamic_scaling
is
not
None
and
alibi_dynamic_scaling
<=
1
)):
raise
ValueError
(
raise
ValueError
(
f
"`alibi_scaling`'s `train_seq_len` field must be an"
f
"`alibi_scaling`'s `train_seq_len` field must be an"
f
"integer > 1, got
{
alibi_dynamic_scaling
}
"
)
f
"integer > 1, got
{
alibi_dynamic_scaling
}
"
)
vllm/transformers_utils/detokenizer.py
View file @
99b471c2
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
from
vllm.transformers_utils.tokenizer
import
(
convert_prompt_ids_to_tokens
,
detokenize_incrementally
)
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
BaseTokenizerGroup
)
...
@@ -89,12 +87,15 @@ class Detokenizer:
...
@@ -89,12 +87,15 @@ class Detokenizer:
prev_tokens
.
extend
(
next_iter_tokens
)
prev_tokens
.
extend
(
next_iter_tokens
)
def
decode_sequence_inplace
(
self
,
seq
:
Sequence
,
def
decode_sequence_inplace
(
self
,
seq
:
Sequence
,
prms
:
SamplingParams
)
->
None
:
prms
:
SamplingParams
)
->
int
:
"""Decodes the new token for a sequence. In-place operation.
"""Decodes the new token for a sequence. In-place operation.
Args:
Args:
seq: The sequence to decode.
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
...
@@ -148,10 +149,165 @@ class Detokenizer:
...
@@ -148,10 +149,165 @@ class Detokenizer:
)
)
sample_logprob
.
decoded_token
=
new_text
sample_logprob
.
decoded_token
=
new_text
if
seq
.
tokens
is
None
:
seq
.
tokens
.
extend
(
new_tokens
)
seq
.
tokens
=
new_tokens
else
:
seq
.
tokens
.
extend
(
new_tokens
)
seq
.
prefix_offset
=
prefix_offset
seq
.
prefix_offset
=
prefix_offset
seq
.
read_offset
=
read_offset
seq
.
read_offset
=
read_offset
seq
.
output_text
+=
new_decoded_token_text
seq
.
output_text
+=
new_decoded_token_text
return
len
(
new_decoded_token_text
)
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
output_tokens
:
List
[
str
],
skip_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
)
->
str
:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts
:
List
[
str
]
=
[]
current_sub_text
:
List
[
str
]
=
[]
all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
for
token
in
output_tokens
:
if
skip_special_tokens
and
token
in
all_special_tokens
:
continue
if
token
in
tokenizer
.
get_added_vocab
():
if
current_sub_text
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
current_sub_text
=
[]
sub_texts
.
append
(
token
)
else
:
current_sub_text
.
append
(
token
)
if
current_sub_text
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
if
spaces_between_special_tokens
:
return
" "
.
join
(
sub_texts
)
else
:
return
""
.
join
(
sub_texts
)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
def
convert_prompt_ids_to_tokens
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
prompt_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
False
,
)
->
Tuple
[
List
[
str
],
int
,
int
]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# We do not need to convert the whole prompt to tokens.
# Offset a little more in case we have special tokens.
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
prompt_ids
[
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
-
2
:],
skip_special_tokens
=
skip_special_tokens
)
read_offset
=
len
(
new_tokens
)
prefix_offset
=
max
(
read_offset
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
return
new_tokens
,
prefix_offset
,
read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def
detokenize_incrementally
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
all_input_ids
:
List
[
int
],
prev_tokens
:
Optional
[
List
[
str
]],
prefix_offset
:
int
,
read_offset
:
int
,
skip_special_tokens
:
bool
=
False
,
spaces_between_special_tokens
:
bool
=
True
,
)
->
Tuple
[
List
[
str
],
str
,
int
,
int
]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id
=
all_input_ids
[
-
1
]
# This is the first iteration for this sequence
is_first_iter
=
prev_tokens
is
None
if
is_first_iter
:
(
prev_tokens
,
prefix_offset
,
read_offset
)
=
convert_prompt_ids_to_tokens
(
tokenizer
,
all_input_ids
[:
-
1
],
skip_special_tokens
=
skip_special_tokens
)
assert
prev_tokens
is
not
None
# If the new token id is out of bounds, return an empty string.
if
new_token_id
>=
len
(
tokenizer
):
new_tokens
=
[
""
]
else
:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
[
new_token_id
],
skip_special_tokens
=
skip_special_tokens
)
if
isinstance
(
new_tokens
,
str
):
new_tokens
=
[
new_tokens
]
output_tokens
=
prev_tokens
+
new_tokens
# If this is the first iteration, return all tokens.
if
is_first_iter
:
new_tokens
=
output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if
tokenizer
.
is_fast
or
not
tokenizer
.
get_added_vocab
():
prefix_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:
read_offset
])
new_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:])
else
:
prefix_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
output_tokens
[
prefix_offset
:
read_offset
],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
)
new_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
output_tokens
[
prefix_offset
:],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
)
if
len
(
new_text
)
<=
len
(
prefix_text
)
or
new_text
.
endswith
(
"�"
):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
return
new_tokens
,
""
,
prefix_offset
,
read_offset
new_text
=
new_text
[
len
(
prefix_text
):]
return
new_tokens
,
new_text
,
read_offset
,
len
(
output_tokens
)
vllm/transformers_utils/tokenizer.py
View file @
99b471c2
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
os
from
typing
import
Optional
,
Union
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
from
transformers
import
(
AutoTokenizer
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
)
PreTrainedTokenizerFast
)
from
vllm.config
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizers
import
*
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -28,7 +30,7 @@ def get_cached_tokenizer(
...
@@ -28,7 +30,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_len
=
len
(
tokenizer
)
tokenizer_len
=
len
(
tokenizer
)
class
CachedTokenizer
(
tokenizer
.
__class__
):
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
@
property
@
property
def
all_special_ids
(
self
):
def
all_special_ids
(
self
):
...
@@ -57,9 +59,26 @@ def get_tokenizer(
...
@@ -57,9 +59,26 @@ def get_tokenizer(
tokenizer_mode
:
str
=
"auto"
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
"""Gets a tokenizer for the given model name via Huggingface."""
"""Gets a tokenizer for the given model name via Huggingface/modelscope."""
if
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
# Only set the tokenizer here, model will be downloaded on the workers.
if
not
os
.
path
.
exists
(
tokenizer_name
):
tokenizer_path
=
snapshot_download
(
model_id
=
tokenizer_name
,
cache_dir
=
download_dir
,
revision
=
tokenizer_revision
,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern
=
[
"*.pt"
,
"*.safetensors"
,
"*.bin"
])
tokenizer_name
=
tokenizer_path
if
tokenizer_mode
==
"slow"
:
if
tokenizer_mode
==
"slow"
:
if
kwargs
.
get
(
"use_fast"
,
False
):
if
kwargs
.
get
(
"use_fast"
,
False
):
raise
ValueError
(
raise
ValueError
(
...
@@ -126,157 +145,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
...
@@ -126,157 +145,3 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
get_lora_tokenizer_async
=
make_async
(
get_lora_tokenizer
)
get_lora_tokenizer_async
=
make_async
(
get_lora_tokenizer
)
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
output_tokens
:
List
[
str
],
skip_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
)
->
str
:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts
=
[]
current_sub_text
=
[]
all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
for
token
in
output_tokens
:
if
skip_special_tokens
and
token
in
all_special_tokens
:
continue
if
token
in
tokenizer
.
get_added_vocab
():
if
current_sub_text
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
current_sub_text
=
[]
sub_texts
.
append
(
token
)
else
:
current_sub_text
.
append
(
token
)
if
current_sub_text
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
if
spaces_between_special_tokens
:
return
" "
.
join
(
sub_texts
)
else
:
return
""
.
join
(
sub_texts
)
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
def
convert_prompt_ids_to_tokens
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
prompt_ids
:
List
[
int
],
skip_special_tokens
:
bool
=
False
,
)
->
Tuple
[
List
[
str
],
int
,
int
]:
"""Converts the prompt ids to tokens and returns the tokens and offsets
for incremental detokenization.
Note that not all tokens are converted to strings. Only the tokens that
are necessary for incremental detokenization are converted to strings.
"""
# Offset a little more in case we have special tokens.
prefix_offset
=
max
(
len
(
prompt_ids
)
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
-
2
,
0
)
# We do not need to convert the whole prompt to tokens.
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
prompt_ids
[
prefix_offset
:],
skip_special_tokens
=
skip_special_tokens
)
prefix_offset
=
max
(
len
(
new_tokens
)
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
read_offset
=
len
(
new_tokens
)
return
new_tokens
,
prefix_offset
,
read_offset
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def
detokenize_incrementally
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
all_input_ids
:
List
[
int
],
prev_tokens
:
Optional
[
List
[
str
]],
prefix_offset
:
int
,
read_offset
:
int
,
skip_special_tokens
:
bool
=
False
,
spaces_between_special_tokens
:
bool
=
True
,
)
->
Tuple
[
List
[
str
],
str
,
int
,
int
]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
If `prev_tokens` is None, this function will convert the input ids to
tokens and return the tokens and the new text. Otherwise, it will return the
new tokens and the new text.
This function will also return the new prefix offset and the new read
offset to be used in the next iteration.
The offsets are necessary to defeat cleanup algorithms in the decode which
decide to add a space or not depending on the surrounding ids.
Args:
tokenizer: The tokenizer to use.
all_input_ids: The input ids. The last id is the new token id.
prev_tokens: The previous tokens. If None, this function will convert
the input ids to tokens and return the tokens and the new text.
prefix_offset: The prefix offset.
read_offset: The read offset.
skip_special_tokens: Whether to skip special tokens.
spaces_between_special_tokens: Whether to add spaces between special
tokens.
"""
new_token_id
=
all_input_ids
[
-
1
]
# This is the first iteration for this sequence
is_first_iter
=
prev_tokens
is
None
if
is_first_iter
:
(
prev_tokens
,
prefix_offset
,
read_offset
)
=
convert_prompt_ids_to_tokens
(
tokenizer
,
all_input_ids
[:
-
1
],
skip_special_tokens
=
skip_special_tokens
)
# If the new token id is out of bounds, return an empty string.
if
new_token_id
>=
len
(
tokenizer
):
new_tokens
=
[
""
]
else
:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
[
new_token_id
],
skip_special_tokens
=
skip_special_tokens
)
output_tokens
=
prev_tokens
+
new_tokens
# If this is the first iteration, return all tokens.
if
is_first_iter
:
new_tokens
=
output_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if
tokenizer
.
is_fast
or
not
tokenizer
.
get_added_vocab
():
prefix_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:
read_offset
])
new_text
=
tokenizer
.
convert_tokens_to_string
(
output_tokens
[
prefix_offset
:])
else
:
prefix_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
output_tokens
[
prefix_offset
:
read_offset
],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
)
new_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
output_tokens
[
prefix_offset
:],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
)
if
len
(
new_text
)
>
len
(
prefix_text
)
and
not
new_text
.
endswith
(
"�"
):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text
=
new_text
[
len
(
prefix_text
):]
return
new_tokens
,
new_text
,
read_offset
,
len
(
output_tokens
)
else
:
return
new_tokens
,
""
,
prefix_offset
,
read_offset
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
99b471c2
...
@@ -11,7 +11,7 @@ if ray:
...
@@ -11,7 +11,7 @@ if ray:
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
RayTokenizerGroupPool
)
RayTokenizerGroupPool
)
else
:
else
:
RayTokenizerGroupPool
=
None
RayTokenizerGroupPool
=
None
# type: ignore
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
99b471c2
...
@@ -51,6 +51,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -51,6 +51,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
enable_lora
=
enable_lora
,
enable_lora
=
enable_lora
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
max_input_length
=
max_input_length
,
max_input_length
=
max_input_length
,
**
tokenizer_config
,
)
)
ray_tokenizer_group_cls
=
ray
.
remote
(
ray_tokenizer_group_cls
=
ray
.
remote
(
...
@@ -88,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -88,6 +89,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is blocking.
This is blocking.
"""
"""
self
.
_ensure_queue_initialized
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
if
self
.
_idle_actors
.
empty
():
if
self
.
_idle_actors
.
empty
():
raise
RuntimeError
(
"No idle actors available."
)
raise
RuntimeError
(
"No idle actors available."
)
...
@@ -119,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -119,6 +121,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
This is non-blocking.
This is non-blocking.
"""
"""
self
.
_ensure_queue_initialized
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
actor
=
await
self
.
_idle_actors
.
get
()
actor
=
await
self
.
_idle_actors
.
get
()
try
:
try
:
...
...
vllm/transformers_utils/tokenizers/baichuan.py
View file @
99b471c2
...
@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
...
@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"tokenizer.model"
}
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"tokenizer.model"
}
PRETRAINED_VOCAB_FILES_MAP
=
{
PRETRAINED_VOCAB_FILES_MAP
=
{
# type: ignore
"vocab_file"
:
{},
"vocab_file"
:
{},
"tokenizer_file"
:
{},
"tokenizer_file"
:
{},
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{}
# type: ignore
class
BaichuanTokenizer
(
PreTrainedTokenizer
):
class
BaichuanTokenizer
(
PreTrainedTokenizer
):
...
@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
...
@@ -114,9 +114,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
token
=
self
.
sp_model
.
IdToPiece
(
index
)
token
=
self
.
sp_model
.
IdToPiece
(
index
)
return
token
return
token
def
convert_tokens_to_string
(
self
,
tokens
):
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
]
):
"""Converts a sequence of tokens (string) in a single string."""
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens
=
[]
current_sub_tokens
:
List
[
str
]
=
[]
out_string
=
""
out_string
=
""
prev_is_special
=
False
prev_is_special
=
False
for
i
,
token
in
enumerate
(
tokens
):
for
i
,
token
in
enumerate
(
tokens
):
...
@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
...
@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved.
`Tuple(str)`: Paths to the files saved.
"""
"""
if
not
os
.
path
.
isdir
(
save_directory
):
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
e
rror
(
f
"Vocabulary path (
{
save_directory
}
) "
raise
ValueE
rror
(
f
"Vocabulary path (
{
save_directory
}
) "
"should be a directory"
)
"should be a directory"
)
return
out_vocab_file
=
os
.
path
.
join
(
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
...
...
vllm/usage/usage_lib.py
View file @
99b471c2
...
@@ -7,7 +7,7 @@ import time
...
@@ -7,7 +7,7 @@ import time
from
enum
import
Enum
from
enum
import
Enum
from
pathlib
import
Path
from
pathlib
import
Path
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cpuinfo
import
cpuinfo
...
@@ -124,7 +124,7 @@ class UsageMessage:
...
@@ -124,7 +124,7 @@ class UsageMessage:
def
report_usage
(
self
,
def
report_usage
(
self
,
model_architecture
:
str
,
model_architecture
:
str
,
usage_context
:
UsageContext
,
usage_context
:
UsageContext
,
extra_kvs
:
Dict
[
str
,
a
ny
]
=
None
)
->
None
:
extra_kvs
:
Optional
[
Dict
[
str
,
A
ny
]
]
=
None
)
->
None
:
t
=
Thread
(
target
=
self
.
_report_usage_worker
,
t
=
Thread
(
target
=
self
.
_report_usage_worker
,
args
=
(
model_architecture
,
usage_context
,
extra_kvs
or
{}),
args
=
(
model_architecture
,
usage_context
,
extra_kvs
or
{}),
daemon
=
True
)
daemon
=
True
)
...
@@ -132,13 +132,13 @@ class UsageMessage:
...
@@ -132,13 +132,13 @@ class UsageMessage:
def
_report_usage_worker
(
self
,
model_architecture
:
str
,
def
_report_usage_worker
(
self
,
model_architecture
:
str
,
usage_context
:
UsageContext
,
usage_context
:
UsageContext
,
extra_kvs
:
Dict
[
str
,
a
ny
])
->
None
:
extra_kvs
:
Dict
[
str
,
A
ny
])
->
None
:
self
.
_report_usage_once
(
model_architecture
,
usage_context
,
extra_kvs
)
self
.
_report_usage_once
(
model_architecture
,
usage_context
,
extra_kvs
)
self
.
_report_continous_usage
()
self
.
_report_continous_usage
()
def
_report_usage_once
(
self
,
model_architecture
:
str
,
def
_report_usage_once
(
self
,
model_architecture
:
str
,
usage_context
:
UsageContext
,
usage_context
:
UsageContext
,
extra_kvs
:
Dict
[
str
,
a
ny
])
->
None
:
extra_kvs
:
Dict
[
str
,
A
ny
])
->
None
:
# Platform information
# Platform information
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
device_property
=
torch
.
cuda
.
get_device_properties
(
0
)
device_property
=
torch
.
cuda
.
get_device_properties
(
0
)
...
...
vllm/utils.py
View file @
99b471c2
import
asyncio
import
asyncio
import
enum
import
enum
import
gc
import
gc
import
glob
import
os
import
os
import
socket
import
socket
import
subprocess
import
subprocess
import
uuid
import
uuid
import
warnings
import
warnings
from
collections
import
OrderedD
ict
from
collections
import
defaultd
ict
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Generic
,
Hashable
,
List
,
from
typing
import
(
Any
,
AsyncIterator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Optional
,
Tuple
,
TypeVar
,
Union
)
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Union
)
import
psutil
import
psutil
import
torch
import
torch
...
@@ -23,9 +25,9 @@ logger = init_logger(__name__)
...
@@ -23,9 +25,9 @@ logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE
=
{
STR_DTYPE_TO_TORCH_DTYPE
=
{
"half"
:
torch
.
half
,
"half"
:
torch
.
half
,
#
"bfloat16": torch.bfloat16,
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"float"
:
torch
.
float
,
# "fp8
_e5m2
": torch.uint8,
# "fp8": torch.uint8,
}
}
...
@@ -51,7 +53,7 @@ class Counter:
...
@@ -51,7 +53,7 @@ class Counter:
class
LRUCache
(
Generic
[
T
]):
class
LRUCache
(
Generic
[
T
]):
def
__init__
(
self
,
capacity
:
int
):
def
__init__
(
self
,
capacity
:
int
):
self
.
cache
=
OrderedDict
[
Hashable
,
T
]()
self
.
cache
:
OrderedDict
[
Hashable
,
T
]
=
OrderedDict
()
self
.
capacity
=
capacity
self
.
capacity
=
capacity
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
...
@@ -60,7 +62,7 @@ class LRUCache(Generic[T]):
...
@@ -60,7 +62,7 @@ class LRUCache(Generic[T]):
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
Hashable
)
->
T
:
def
__getitem__
(
self
,
key
:
Hashable
)
->
Optional
[
T
]
:
return
self
.
get
(
key
)
return
self
.
get
(
key
)
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
...
@@ -76,7 +78,7 @@ class LRUCache(Generic[T]):
...
@@ -76,7 +78,7 @@ class LRUCache(Generic[T]):
key
:
Hashable
,
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
if
key
in
self
.
cache
:
if
key
in
self
.
cache
:
value
=
self
.
cache
[
key
]
value
:
Optional
[
T
]
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
else
:
else
:
value
=
default_value
value
=
default_value
...
@@ -87,7 +89,7 @@ class LRUCache(Generic[T]):
...
@@ -87,7 +89,7 @@ class LRUCache(Generic[T]):
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
self
.
_remove_old_if_needed
()
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
T
):
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]
):
pass
pass
def
remove_oldest
(
self
):
def
remove_oldest
(
self
):
...
@@ -100,9 +102,11 @@ class LRUCache(Generic[T]):
...
@@ -100,9 +102,11 @@ class LRUCache(Generic[T]):
while
len
(
self
.
cache
)
>
self
.
capacity
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
self
.
remove_oldest
()
self
.
remove_oldest
()
def
pop
(
self
,
key
:
Hashable
,
default_value
:
Optional
[
Any
]
=
None
)
->
T
:
def
pop
(
self
,
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
run_on_remove
=
key
in
self
.
cache
run_on_remove
=
key
in
self
.
cache
value
=
self
.
cache
.
pop
(
key
,
default_value
)
value
:
Optional
[
T
]
=
self
.
cache
.
pop
(
key
,
default_value
)
if
run_on_remove
:
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
self
.
_on_remove
(
key
,
value
)
return
value
return
value
...
@@ -117,6 +121,15 @@ def is_hip() -> bool:
...
@@ -117,6 +121,15 @@ def is_hip() -> bool:
return
torch
.
version
.
hip
is
not
None
return
torch
.
version
.
hip
is
not
None
@
lru_cache
(
maxsize
=
None
)
def
is_cpu
()
->
bool
:
from
importlib.metadata
import
PackageNotFoundError
,
version
try
:
return
"cpu"
in
version
(
"vllm"
)
except
PackageNotFoundError
:
return
False
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
is_neuron
()
->
bool
:
def
is_neuron
()
->
bool
:
try
:
try
:
...
@@ -150,6 +163,17 @@ def random_uuid() -> str:
...
@@ -150,6 +163,17 @@ def random_uuid() -> str:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
@
lru_cache
(
maxsize
=
None
)
def
get_vllm_instance_id
():
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return
os
.
environ
.
get
(
"VLLM_INSTANCE_ID"
,
f
"vllm-instance-
{
random_uuid
()
}
"
)
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
in_wsl
()
->
bool
:
def
in_wsl
()
->
bool
:
# Reference: https://github.com/microsoft/WSL/issues/4071
# Reference: https://github.com/microsoft/WSL/issues/4071
...
@@ -171,7 +195,43 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
...
@@ -171,7 +195,43 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
return
_async_wrapper
return
_async_wrapper
def merge_async_iterators(
        *iterators: "AsyncIterator[T]") -> "AsyncIterator[Tuple[int, T]]":
    """Merge multiple asynchronous iterators into a single iterator.

    One producer task is started per input iterator (so this must be called
    while an event loop is running). Every produced item is forwarded as a
    tuple ``(i, item)`` where ``i`` is the index of the iterator that yielded
    the item. Iterators may finish at different times, and may be empty.

    Args:
        *iterators: The asynchronous iterators to merge.

    Returns:
        An asynchronous iterator over ``(index, item)`` tuples. Items of a
        single input keep their relative order; no order is guaranteed
        between different inputs.

    Raises:
        Exception: The first exception raised by any input iterator is
            re-raised from the merged iterator, after the items that were
            queued before it.
    """
    # Marker enqueued by each producer when it is finished. Counting markers
    # (instead of polling a shared "finished" list) means the consumer is
    # always woken up when a producer ends, even if that producer never
    # yielded an item -- otherwise the consumer could wait on an empty queue
    # forever.
    done = object()
    queue: asyncio.Queue = asyncio.Queue()

    async def producer(i: int, iterator: "AsyncIterator[T]") -> None:
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finally:
            # The queue is unbounded, so this never blocks or fails; it also
            # runs when the task is cancelled.
            queue.put_nowait(done)

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        pending = len(_tasks)
        try:
            while pending:
                item = await queue.get()
                if item is done:
                    pending -= 1
                elif isinstance(item, Exception):
                    raise item
                else:
                    yield item
        finally:
            # On normal completion every task is already done and cancel()
            # is a no-op. On error or early close this stops the remaining
            # producers instead of leaving them running in the background.
            for task in _tasks:
                task.cancel()
            await asyncio.gather(*_tasks, return_exceptions=True)

    return consumer()
def
get_ip
()
->
str
:
def
get_ip
()
->
str
:
host_ip
=
os
.
environ
.
get
(
"HOST_IP"
)
host_ip
=
os
.
environ
.
get
(
"HOST_IP"
)
if
host_ip
:
if
host_ip
:
...
@@ -223,8 +283,12 @@ def get_open_port() -> int:
...
@@ -223,8 +283,12 @@ def get_open_port() -> int:
return
s
.
getsockname
()[
1
]
return
s
.
getsockname
()[
1
]
def
set_cuda_visible_devices
(
device_ids
:
List
[
int
])
->
None
:
def
update_environment_variables
(
envs
:
Dict
[
str
,
str
]):
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
","
.
join
(
map
(
str
,
device_ids
))
for
k
,
v
in
envs
.
items
():
if
k
in
os
.
environ
and
os
.
environ
[
k
]
!=
v
:
logger
.
warning
(
f
"Overwriting environment variable
{
k
}
"
f
"from '
{
os
.
environ
[
k
]
}
' to '
{
v
}
'"
)
os
.
environ
[
k
]
=
v
def
chunk_list
(
lst
,
chunk_size
):
def
chunk_list
(
lst
,
chunk_size
):
...
@@ -257,7 +321,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
...
@@ -257,7 +321,7 @@ def get_nvcc_cuda_version() -> Optional[Version]:
return
nvcc_cuda_version
return
nvcc_cuda_version
def
_generate_random_fp8
_e5m2
(
def
_generate_random_fp8
(
tensor
:
torch
.
tensor
,
tensor
:
torch
.
tensor
,
low
:
float
,
low
:
float
,
high
:
float
,
high
:
float
,
...
@@ -270,10 +334,10 @@ def _generate_random_fp8_e5m2(
...
@@ -270,10 +334,10 @@ def _generate_random_fp8_e5m2(
#-----|-------------|-------------------
#-----|-------------|-------------------
# Inf | N/A | s.11111.00
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
# NaN | s.1111.111 | s.11111.{01,10,11}
from
vllm
._C
import
cache_
ops
from
vllm
import
_custom_ops
as
ops
tensor_tmp
=
torch
.
empty_like
(
tensor
,
dtype
=
torch
.
float16
)
tensor_tmp
=
torch
.
empty_like
(
tensor
,
dtype
=
torch
.
float16
)
tensor_tmp
.
uniform_
(
low
,
high
)
tensor_tmp
.
uniform_
(
low
,
high
)
cache_
ops
.
convert_fp8
_e5m2
(
tensor_tmp
,
tensor
)
ops
.
convert_fp8
(
tensor_tmp
,
tensor
)
del
tensor_tmp
del
tensor_tmp
...
@@ -285,7 +349,7 @@ def create_kv_caches_with_random(
...
@@ -285,7 +349,7 @@ def create_kv_caches_with_random(
head_size
:
int
,
head_size
:
int
,
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
Optional
[
int
]
=
0
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
...
@@ -302,7 +366,7 @@ def create_kv_caches_with_random(
...
@@ -302,7 +366,7 @@ def create_kv_caches_with_random(
raise
ValueError
(
f
"Invalid model dtype:
{
model_dtype
}
"
)
raise
ValueError
(
f
"Invalid model dtype:
{
model_dtype
}
"
)
elif
cache_dtype
in
[
"half"
,
"bfloat16"
,
"float"
]:
elif
cache_dtype
in
[
"half"
,
"bfloat16"
,
"float"
]:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
elif
cache_dtype
==
"fp8
_e5m2
"
:
elif
cache_dtype
==
"fp8"
:
torch_dtype
=
torch
.
uint8
torch_dtype
=
torch
.
uint8
else
:
else
:
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
...
@@ -319,10 +383,10 @@ def create_kv_caches_with_random(
...
@@ -319,10 +383,10 @@ def create_kv_caches_with_random(
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
if
cache_dtype
==
'fp8_e5m2'
:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
_generate_random_fp8_e5m2
(
key_cache
,
-
scale
,
scale
)
elif
torch_dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]:
key_cache
.
uniform_
(
-
scale
,
scale
)
key_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_cache
,
-
scale
,
scale
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
f
"Does not support key cache of type
{
cache_dtype
}
"
)
...
@@ -334,10 +398,10 @@ def create_kv_caches_with_random(
...
@@ -334,10 +398,10 @@ def create_kv_caches_with_random(
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
torch_dtype
,
dtype
=
torch_dtype
,
device
=
device
)
device
=
device
)
if
cache_dtype
==
'fp8_e5m2'
:
if
cache_dtype
in
[
"auto"
,
"half"
,
"bfloat16"
,
"float"
]:
_generate_random_fp8_e5m2
(
value_cache
,
-
scale
,
scale
)
elif
torch_dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]:
value_cache
.
uniform_
(
-
scale
,
scale
)
value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
value_cache
,
-
scale
,
scale
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Does not support value cache of type
{
cache_dtype
}
"
)
f
"Does not support value cache of type
{
cache_dtype
}
"
)
...
@@ -362,6 +426,8 @@ def is_pin_memory_available() -> bool:
...
@@ -362,6 +426,8 @@ def is_pin_memory_available() -> bool:
elif
is_neuron
():
elif
is_neuron
():
print_warning_once
(
"Pin memory is not supported on Neuron."
)
print_warning_once
(
"Pin memory is not supported on Neuron."
)
return
False
return
False
elif
is_cpu
():
return
False
return
True
return
True
...
@@ -389,7 +455,7 @@ class CudaMemoryProfiler:
...
@@ -389,7 +455,7 @@ class CudaMemoryProfiler:
gc
.
collect
()
gc
.
collect
()
def
str_to_int_tuple
(
s
:
str
)
->
Tuple
[
int
]:
def
str_to_int_tuple
(
s
:
str
)
->
Tuple
[
int
,
...
]:
"""Convert a string to a tuple of integers."""
"""Convert a string to a tuple of integers."""
try
:
try
:
return
tuple
(
map
(
int
,
s
.
split
(
","
)))
return
tuple
(
map
(
int
,
s
.
split
(
","
)))
...
@@ -438,3 +504,106 @@ def maybe_expand_dim(tensor: torch.Tensor,
...
@@ -438,3 +504,106 @@ def maybe_expand_dim(tensor: torch.Tensor,
if
tensor
.
ndim
<
target_dims
:
if
tensor
.
ndim
<
target_dims
:
tensor
=
tensor
.
view
(
-
1
,
*
([
size
]
*
(
target_dims
-
tensor
.
ndim
)))
tensor
=
tensor
.
view
(
-
1
,
*
([
size
]
*
(
target_dims
-
tensor
.
ndim
)))
return
tensor
return
tensor
def merge_dicts(dict1: Dict[Any, List[Any]],
                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
    """Merge two ``key -> list of items`` dicts into a new dict.

    When a key is present in both inputs the two lists are concatenated,
    with the items of ``dict1`` placed before the items of ``dict2``.
    Neither input dict, nor any of the lists they hold, is modified.

    Args:
        dict1: First mapping; its keys and items come first in the result.
        dict2: Second mapping.

    Returns:
        A plain ``dict`` holding a freshly built list for every key that
        appears in either input.
    """
    combined: Dict[Any, List[Any]] = {}
    for source in (dict1, dict2):
        for key, items in source.items():
            combined.setdefault(key, []).extend(items)
    return combined
def init_cached_hf_modules():
    """Lazily run the Hugging Face ``init_hf_modules()`` initializer.

    ``transformers`` is imported inside the function so that merely
    importing this module does not import it.
    """
    import transformers.dynamic_module_utils as hf_dynamic_modules

    hf_dynamic_modules.init_hf_modules()
def nccl_integrity_check(filepath):
    """Check that an NCCL shared library is loadable and return its version.

    When the library is corrupted, loading it cannot be caught as a Python
    exception: it crashes the process. So the exit code of ``ldd`` is used
    first to check whether the library is corrupted; only if ``ldd``
    succeeds is the library loaded and queried.

    Args:
        filepath: Path of the NCCL shared library file.

    Returns:
        The integer written by the library's ``ncclGetVersion``.

    Raises:
        RuntimeError: If ``ldd`` cannot be run or exits with a non-zero
            status for ``filepath``, or if ``ncclGetVersion`` reports an
            error.
    """
    import ctypes
    import subprocess

    # Run ldd with an argument list rather than through a shell, so paths
    # containing spaces or shell metacharacters are passed through verbatim.
    # Only stdout is discarded; ldd's own error output stays visible.
    try:
        exit_code = subprocess.run(["ldd", filepath],
                                   stdout=subprocess.DEVNULL,
                                   check=False).returncode
    except OSError as e:
        raise RuntimeError(
            f"Failed to load NCCL library from {filepath} .") from e
    if exit_code != 0:
        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")

    nccl = ctypes.CDLL(filepath)
    version = ctypes.c_int()
    nccl.ncclGetVersion.restype = ctypes.c_int
    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
    result = nccl.ncclGetVersion(ctypes.byref(version))
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if result != 0:
        raise RuntimeError(
            f"ncclGetVersion from {filepath} failed with code {result}.")
    return version.value
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """Find the library file in the system.

    `lib_name` is the full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.
    Results are memoized per `lib_name` by ``lru_cache``.

    Args:
        lib_name: File name of the library, e.g. ``libnccl.so.2``.

    Returns:
        The first matching path: from the ``/sbin/ldconfig -p`` listing if
        it has a match, otherwise from the directories in
        ``LD_LIBRARY_PATH``.

    Raises:
        ValueError: If the library is found in neither place.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    try:
        libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    except (OSError, subprocess.CalledProcessError):
        # ldconfig is missing or failed: treat its listing as empty so the
        # LD_LIBRARY_PATH lookup below still gets a chance.
        libs = ""
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if not locs and env_ld_library_path:
        candidates = (os.path.join(directory, lib_name)
                      for directory in env_ld_library_path.split(":"))
        locs = [path for path in candidates if os.path.exists(path)]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]
def find_nccl_library():
    """Pick the NCCL (or RCCL) shared library to load.

    Resolution order:

    1. The ``VLLM_NCCL_SO_PATH`` environment variable, when non-empty.
    2. On a CUDA build of torch: the first file matching
       ``~/.config/vllm/nccl/cu<major>/libnccl.so.*`` if one exists,
       otherwise whatever ``find_library("libnccl.so.2")`` resolves.
    3. On a ROCm build of torch: ``find_library("librccl.so.1")``.

    Returns:
        The selected library path (or name, as given by the environment
        variable).

    Raises:
        ValueError: If no override is set and torch was built for neither
            CUDA nor ROCm.
    """
    override = os.environ.get("VLLM_NCCL_SO_PATH", "")
    cuda_version = torch.version.cuda

    # check if we have vllm-managed nccl
    managed_nccl = None
    if cuda_version is not None:
        major = cuda_version.split(".")[0]
        pattern = os.path.expanduser(
            f"~/.config/vllm/nccl/cu{major}/libnccl.so.*")
        matches = glob.glob(pattern)
        if matches:
            managed_nccl = matches[0]

    # An explicit override wins over every discovery mechanism.
    if override:
        logger.info(
            f"Found nccl from environment variable VLLM_NCCL_SO_PATH={override}"
        )
        return override

    if cuda_version is not None:
        resolved = managed_nccl or find_library("libnccl.so.2")
    elif torch.version.hip is not None:
        resolved = find_library("librccl.so.1")
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.info(f"Found nccl from library {resolved}")
    return resolved
vllm/worker/cache_engine.py
View file @
99b471c2
...
@@ -82,8 +82,7 @@ class CacheEngine:
...
@@ -82,8 +82,7 @@ class CacheEngine:
@
staticmethod
@
staticmethod
def
get_cache_block_size
(
def
get_cache_block_size
(
block_size
:
int
,
cache_config
:
CacheConfig
,
cache_dtype
:
str
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
)
->
int
:
)
->
int
:
...
@@ -91,13 +90,13 @@ class CacheEngine:
...
@@ -91,13 +90,13 @@ class CacheEngine:
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
num_layers
=
model_config
.
get_num_layers
(
parallel_config
)
key_cache_block
=
block_size
*
num_heads
*
head_size
key_cache_block
=
cache_config
.
block_size
*
num_heads
*
head_size
value_cache_block
=
key_cache_block
value_cache_block
=
key_cache_block
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
total
=
num_layers
*
(
key_cache_block
+
value_cache_block
)
if
cache_dtype
==
"auto"
:
if
cache_config
.
cache_dtype
==
"auto"
:
dtype
=
model_config
.
dtype
dtype
=
model_config
.
dtype
else
:
else
:
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_
config
.
cache_
dtype
]
dtype_size
=
_get_dtype_size
(
dtype
)
dtype_size
=
_get_dtype_size
(
dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
...
...
vllm/worker/cpu_model_runner.py
0 → 100644
View file @
99b471c2
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
,
maybe_expand_dim
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
class CPUModelRunner:
    """Builds model inputs and runs the forward pass and sampling.

    ``load_model()`` assigns ``self.model``. ``self.block_size`` is not set
    by this class; it is assigned from outside (``CPUWorker`` does so in
    ``_init_cache_engine``) and is read by the ``_prepare_*`` methods.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store the configuration objects and pick the attention backend.

        Args:
            model_config: Model configuration; may be ``None`` (see the
                FIXME below).
            parallel_config: Stored and later passed to ``get_model``.
            scheduler_config: Stored and later passed to ``get_model``.
            device_config: Device configuration; a default ``DeviceConfig()``
                is created when ``None``.
            load_config: Stored and later passed to ``get_model``.
            lora_config: Stored and later passed to ``get_model``.
            vision_language_config: When truthy, multi-modal input is
                accepted and forwarded to the model as ``image_input``.
            kv_cache_dtype: Forwarded to the attention metadata.
            is_driver_worker: Whether this runner prepares and broadcasts
                the inputs (``True``) or receives them (``False``).
            *args: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

        # Lazy initialization: declared here, assigned later.
        self.model: nn.Module  # Assigned by load_model().
        self.block_size: int  # Assigned from outside after construction.

    def load_model(self) -> None:
        """Load the model via ``get_model`` and store it in ``self.model``."""
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
        """Build the model inputs for a batch of prompt sequence groups.

        Every group must be a prompt and hold exactly one sequence. The
        tokens, positions and KV-cache slots of all prompts are concatenated
        into flat lists and converted to ``torch.long`` tensors on
        ``self.device``.

        Args:
            seq_group_metadata_list: Non-empty list of prompt groups.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, prompt_lens,
            multi_modal_input)``; ``multi_modal_input`` is ``None`` when no
            group carries multi-modal data.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        prompt_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            prompt_len = len(prompt_tokens)

            prompt_lens.append(prompt_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            # NOTE(review): all prompt tokens are appended above, while
            # positions and slots start at computed_len; the lengths only
            # agree when computed_len is 0 -- confirm that always holds here.
            input_positions.extend(list(range(computed_len, prompt_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, prompt_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            # NOTE(review): the trailing "0, 1" in that example presumes the
            # block table maps the third block back to block 0 -- confirm.
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, prompt_len - self.sliding_window)

            for i in range(computed_len, prompt_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                # Slot = (block number from the block table) * block size
                # + offset of the token inside its block.
                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        # Prefill metadata: no decode tokens, and an empty block-table tensor.
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            prompt_lens=prompt_lens,
            num_prefills=len(prompt_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            prefill_metadata=None,
            decode_metadata=None,
            max_context_len=None,
            context_lens=None,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build the model inputs for one decode step.

        Each sequence of each group contributes exactly one input token (its
        last token id), one position (``seq_len - 1``), one KV-cache slot and
        one block table.

        Args:
            seq_group_metadata_list: Non-empty list of non-prompt groups,
                each with ``token_chunk_size == 1``.

        Returns:
            ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        context_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                # The context length is capped at the sliding window, if any.
                context_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                context_lens.append(context_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks that the window spans.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_context_len = max(context_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)

        # Block tables differ in length; pad them with 0 to the longest one.
        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            prompt_lens=None,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            max_context_len=max_context_len,
            num_prefills=0,
            prefill_metadata=None,
            decode_metadata=None,
            context_lens=context_lens,
            block_tables=block_tables,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        """Build the ``SamplingMetadata`` for the current batch.

        Args:
            seq_group_metadata_list: The groups of the batch, in batch order.
            prompt_lens: Prompt length per group; indexed only for prompt
                groups (the caller passes ``[]`` for a decode batch).

        Returns:
            A ``SamplingMetadata`` holding the selected token indices, the
            per-``SamplingType`` index pairs, the merged sequence data and
            the generators of seeded groups.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        # Running offset of the current group within the flattened tokens.
        selected_token_start_idx = 0
        categorized_sample_indices: Dict[SamplingType,
                                         List[Tuple[int, int]]] = {
                                             t: []
                                             for t in SamplingType
                                         }
        # Two running counters; each recorded entry pairs their current
        # values.
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                subquery_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need sample, skip
                    categorized_sample_indices_start_idx += subquery_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        (categorized_sample_indices_start_idx,
                         categorized_sampled_token_indices_start_idx))
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                # With prompt logprobs every prompt position is selected;
                # otherwise only the last prompt position is.
                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + subquery_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              subquery_len - 1)
                selected_token_start_idx += subquery_len

                # A seeded request gets its own generator, created at prompt
                # time and kept on the group's state.
                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                # Decode: one selected position per sequence in the group.
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = torch.tensor(selected_token_indices,
                                              dtype=torch.long)

        # Convert each per-type list of pairs into an int tensor.
        categorized_sample_indices = {
            t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[torch.Tensor]]:
        """Prepare the inputs on the driver, or receive them on other workers.

        The driver worker builds the tensors and metadata and broadcasts them
        with ``broadcast_tensor_dict(..., src=0)``. Any other worker receives
        that dict, rebuilds the attention metadata from it and creates a
        ``SamplingMetadata`` with ``perform_sampling=False``.

        Args:
            seq_group_metadata_list: The batch; only read on the driver.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_input)``. ``multi_modal_input``
            is only set on the driver for a prompt batch; it is ``None``
            otherwise.
        """
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, prompt_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                prompt_lens = []
            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the non-attention entries; what remains is passed as the
            # keyword arguments of make_metadata.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                prompt_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
                perform_sampling=False,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, where sampling is enabled, sample.

        Args:
            seq_group_metadata_list: The batch to execute.
            kv_caches: KV cache tensors, passed to the model unchanged.

        Returns:
            The result of ``self.model.sample(...)``, or ``None`` when
            ``sampling_metadata.perform_sampling`` is false.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        # The image input is only passed when a vision-language config is set.
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not sampling_metadata.perform_sampling:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
vllm/worker/cpu_worker.py
0 → 100644
View file @
99b471c2
"""A CPU worker class."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.distributed
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
logger
=
init_logger
(
__name__
)
class CPUCacheEngine:
    """Owns the KV cache tensors of the CPU backend.

    The engine allocates one cache tensor per model layer at construction
    time and exposes the block-level cache operations. Copying is delegated
    to the attention backend; swapping is not supported.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        """Read the cache geometry from the configs and allocate the cache.

        Args:
            cache_config: Supplies ``block_size``, ``num_gpu_blocks`` and
                ``cache_dtype``.
            model_config: Supplies head size, layer count, KV head count and
                the model dtype.
            parallel_config: Passed to the layer/head count queries.
            device_config: Must describe a ``"cpu"`` device.
        """
        assert device_config.device_type == "cpu"

        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)
        self.block_size = cache_config.block_size

        # Note: for the CPU backend CacheConfig.num_gpu_blocks actually
        # holds the number of CPU blocks, because the scheduler's KV cache
        # management is reused as-is.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        # "auto" means: store the cache in the model's own dtype.
        requested_dtype = cache_config.cache_dtype
        self.dtype = (model_config.dtype if requested_dtype == "auto" else
                      STR_DTYPE_TO_TORCH_DTYPE[requested_dtype])

        self.attn_backend = get_attn_backend(model_config.dtype)
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocate one uninitialized CPU cache tensor per layer."""
        shape = self.attn_backend.get_kv_cache_shape(num_blocks,
                                                     self.block_size,
                                                     self.num_heads,
                                                     self.head_size)
        return [
            torch.empty(shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
        """Copy cache blocks through the attention backend's ``copy_blocks``."""
        backend = self.attn_backend
        backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_cache_block_size(
        block_size: int,
        cache_dtype: str,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one cache block across all layers.

        Args:
            block_size: Number of token slots per block.
            cache_dtype: ``"auto"`` to use the model dtype, otherwise a key
                of ``STR_DTYPE_TO_TORCH_DTYPE``.
            model_config: Supplies head size, KV head count, layer count and
                the model dtype.
            parallel_config: Passed to the layer/head count queries.

        Returns:
            ``element_size * num_layers * (key_elements + value_elements)``.
        """
        head_size = model_config.get_head_size()
        kv_heads = model_config.get_num_kv_heads(parallel_config)
        layers = model_config.get_num_layers(parallel_config)

        # Key and value blocks hold the same number of elements.
        elements_per_tensor = block_size * kv_heads * head_size
        elements_total = layers * (elements_per_tensor + elements_per_tensor)

        element_dtype = (model_config.dtype if cache_dtype == "auto" else
                         STR_DTYPE_TO_TORCH_DTYPE[cache_dtype])
        bytes_per_element = torch.tensor([], dtype=element_dtype).element_size()
        return bytes_per_element * elements_total
class
CPUWorker
(
LoraNotSupportedWorkerBase
):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
responsible for maintaining the KV cache and executing the model on the
CPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    cache_config: CacheConfig,
    load_config: LoadConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    lora_config: Optional[LoRAConfig] = None,
    vision_language_config: Optional[VisionLanguageConfig] = None,
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
) -> None:
    """Store the configuration and create the CPU model runner.

    Args:
        model_config: Model configuration; ``trust_remote_code`` decides
            whether the Hugging Face module initializer is called.
        parallel_config: Stored and forwarded to the model runner.
        scheduler_config: Stored and forwarded to the model runner.
        device_config: Stored and forwarded to the model runner.
        cache_config: Stored on the worker.
        load_config: Stored and forwarded to the model runner.
        local_rank: Stored on the worker.
        rank: Stored on the worker; must be 0 for the driver worker.
        distributed_init_method: Stored on the worker.
        lora_config: Stored and forwarded to the model runner.
        vision_language_config: Stored and forwarded to the model runner.
        kv_cache_dtype: Forwarded to the model runner.
        is_driver_worker: Whether this worker is the driver.
    """
    # Configuration objects.
    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.cache_config = cache_config
    self.load_config = load_config

    # Placement of this worker within the distributed job.
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method

    # Optional features.
    self.lora_config = lora_config
    self.vision_language_config = vision_language_config

    self.is_driver_worker = is_driver_worker
    if is_driver_worker:
        assert self.rank == 0, "The driver worker must have rank 0."

    if model_config.trust_remote_code:
        # note: lazy import to avoid importing torch before initializing
        from vllm.utils import init_cached_hf_modules
        init_cached_hf_modules()

    runner_options = dict(
        load_config=self.load_config,
        lora_config=self.lora_config,
        vision_language_config=self.vision_language_config,
        kv_cache_dtype=kv_cache_dtype,
        is_driver_worker=is_driver_worker,
    )
    self.model_runner = CPUModelRunner(model_config, parallel_config,
                                       scheduler_config, device_config,
                                       **runner_options)

    # Declared only: both attributes are assigned once the cache engine
    # is created (see initialize_cache).
    self.cache_engine: CPUCacheEngine
    self.cpu_cache: List[torch.Tensor]
def init_device(self) -> None:
    """Set up the distributed environment, then seed the random generators.

    The seed is taken from ``self.model_config.seed``.
    """
    self.init_distributed_environment()

    seed = self.model_config.seed
    set_random_seed(seed)
def load_model(self):
    """Load the model by delegating to the model runner."""
    runner = self.model_runner
    runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of blocks available for the KV cache.

    The count is how many blocks of ``get_cache_block_size_bytes()`` bytes
    fit into ``cache_config.cpu_kvcache_space_bytes`` (never negative).

    Since vLLM assumes a block resides on GPU if it can be modified, the
    CPU blocks are reported in the "GPU" position and the swappable "CPU"
    position is always 0. This lets the scheduler be reused without
    generalizing it to different devices.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks)`` where the second item is 0.
    """
    bytes_per_block = self.get_cache_block_size_bytes()
    kv_space_bytes = self.cache_config.cpu_kvcache_space_bytes
    usable_blocks = max(int(kv_space_bytes // bytes_per_block), 0)

    # The CPU cache plays the role of the 'gpu cache'; nothing is swappable.
    return usable_blocks, 0
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
    """Initialize the KV cache; swappable CPU memory is not supported.

    This worker has no GPU, so ``num_gpu_blocks`` is interpreted as the
    number of non-swappable CPU blocks to allocate.

    Args:
        num_gpu_blocks: Number of blocks to allocate (in CPU memory).
        num_cpu_blocks: Number of swappable blocks; must be ``0``.

    Raises:
        ValueError: Propagated from ``_validate_num_cpu_blocks`` when the
            requested block count is rejected.
    """
    assert num_cpu_blocks == 0, (
        f"{type(self)} does not support swappable cache")

    # The CPU cache stands in for the "GPU cache" so that the cache
    # management procedure can be reused.
    self._validate_num_cpu_blocks(num_gpu_blocks)

    cache_cfg = self.cache_config
    cache_cfg.num_gpu_blocks = num_gpu_blocks
    cache_cfg.num_cpu_blocks = 0

    # Allocate the cache itself.
    self._init_cache_engine()
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
    """Check that ``num_cpu_blocks`` is usable for the KV cache.

    Args:
        num_cpu_blocks: Number of cache blocks that would be allocated.

    Raises:
        ValueError: If the count is not positive, or if
            ``block_size * num_cpu_blocks`` is smaller than the model's
            ``max_model_len``.
    """
    if num_cpu_blocks <= 0:
        no_memory_msg = ("No available memory for the cache blocks. "
                         "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                         "initializing the engine.")
        raise ValueError(no_memory_msg)

    # Total number of tokens the cache can hold across all blocks.
    token_capacity = self.cache_config.block_size * num_cpu_blocks
    model_len = self.model_config.max_model_len
    if model_len <= token_capacity:
        return

    raise ValueError(
        f"The model's max seq len ({model_len}) "
        "is larger than the maximum number of tokens that can be "
        f"stored in KV cache ({token_capacity}). Try increasing "
        "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
        "initializing the engine.")
def _init_cache_engine(self) -> None:
    """Create the CPU cache engine and wire it into the worker.

    Builds a ``CPUCacheEngine`` from the worker's cache, model, parallel and
    device configs, exposes its ``cpu_cache`` on the worker, copies its
    ``block_size`` onto the model runner, and finally zero-fills every layer
    cache.
    """
    engine = CPUCacheEngine(
        self.cache_config,
        self.model_config,
        self.parallel_config,
        self.device_config,
    )
    self.cache_engine = engine
    self.cpu_cache = engine.cpu_cache
    self.model_runner.block_size = engine.block_size

    assert self.cpu_cache is not None

    # Writing to every layer cache warms up the memory.
    for per_layer_cache in self.cpu_cache:
        per_layer_cache.fill_(0)
def cache_copy(
    self,
    blocks_to_copy: Dict[int, List[int]],
) -> None:
    """Forward a non-empty block-copy request to the cache engine.

    Args:
        blocks_to_copy: Mapping passed unchanged to ``cache_engine.copy``.
            When it is falsy (e.g. empty) nothing is done.
    """
    if not blocks_to_copy:
        return
    self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> List[SamplerOutput]:
    """Run one model step on this worker.

    The driver worker checks its arguments and broadcasts the number of
    sequence groups together with ``blocks_to_copy`` (``src=0``); every
    other worker obtains both values from that broadcast. Afterwards the
    requested block copies are applied and the model runner is executed.

    Args:
        seq_group_metadata_list: Sequence groups to run; required on the
            driver worker.
        blocks_to_swap_in: Required on the driver worker and must be empty.
        blocks_to_swap_out: Required on the driver worker and must be empty.
        blocks_to_copy: Block-copy request; required on the driver worker.

    Returns:
        An empty list when there are no sequence groups, otherwise a
        one-element list holding the model runner's output.
    """
    if self.is_driver_worker:
        assert seq_group_metadata_list is not None
        group_count: int = len(seq_group_metadata_list)
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None
        # Swapping is not available on this worker, so both maps are empty.
        assert len(blocks_to_swap_in) == 0
        assert len(blocks_to_swap_out) == 0
        payload: Dict[str, Any] = {
            "num_seq_groups": group_count,
            "blocks_to_copy": blocks_to_copy,
        }
        broadcast_tensor_dict(payload, src=0)
    else:
        payload = broadcast_tensor_dict(src=0)
        group_count = payload["num_seq_groups"]
        blocks_to_copy = payload["blocks_to_copy"]

    assert blocks_to_copy is not None
    self.cache_copy(blocks_to_copy)

    # Nothing to run when the batch is empty.
    if group_count == 0:
        return []

    step_output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.cpu_cache)

    # CPU worker only supports single-step execution.
    return [step_output]
def init_distributed_environment(self) -> None:
    """Initialize the distributed environment."""
    parallel_cfg = self.parallel_config

    # Bring up the process group with the "gloo" backend.
    init_distributed_environment(
        world_size=parallel_cfg.world_size,
        rank=self.rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )

    # A small all_reduce for warmup.
    warmup_tensor = torch.zeros(1).cpu()
    torch.distributed.all_reduce(warmup_tensor)

    ensure_model_parallel_initialized(
        parallel_cfg.tensor_parallel_size,
        parallel_cfg.pipeline_parallel_size,
    )
def get_cache_block_size_bytes(self) -> int:
    """Return the size in bytes of a single KV cache block.

    The value is computed by ``CPUCacheEngine.get_cache_block_size`` from
    the cache config's block size and cache dtype plus the model and
    parallel configs.
    """
    cache_cfg = self.cache_config
    return CPUCacheEngine.get_cache_block_size(
        cache_cfg.block_size,
        cache_cfg.cache_dtype,
        self.model_config,
        self.parallel_config,
    )
vllm/worker/model_runner.py
View file @
99b471c2
import
contextlib
import
contextlib
import
time
import
time
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
enum
import
IntEnum
from
typing
import
Dict
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
from
vllm.config
import
(
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
get_attn_backend
)
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
from
vllm.distributed.device_communicators
import
(
custom_all_reduce
,
pynccl_utils
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.parallel_utils
import
custom_all_reduce
,
pynccl_utils
from
vllm.model_executor.parallel_utils.communication_op
import
(
broadcast_tensor_dict
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
with_pynccl_for_all_reduce
)
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
async_tensor_h2d
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
maybe_expand_dim
)
...
@@ -39,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
...
@@ -39,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
]
class PreparePromptMetadata(NamedTuple):
    """Bundle of values returned by ``ModelRunner._prepare_prompt``.

    Field order is part of the interface: callers may unpack the tuple
    positionally.
    """
    input_tokens: List[int]
    input_positions: List[int]
    # None when there were no prompt requests to prepare (see ``empty``).
    attn_metadata: Optional[AttentionMetadataPerStage]
    prompt_lens: List[int]
    subquery_lens: List[int]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    multi_modal_input: Optional[torch.Tensor]
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing "no prompt requests".

        Every list/set field is a fresh empty container and every optional
        field is ``None``. The instance is built through ``cls`` so that a
        subclass gets an instance of its own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            prompt_lens=[],
            subquery_lens=[],
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=[],
        )
class PrepareDecodeMetadata(NamedTuple):
    """Bundle of values returned by ``ModelRunner._prepare_decode``.

    Field order is part of the interface: callers may unpack the tuple
    positionally.
    """
    input_tokens: List[int]
    input_positions: List[int]
    # None when there were no decode requests to prepare (see ``empty``).
    attn_metadata: Optional[AttentionMetadata]
    lora_index_mapping: List[int]
    lora_prompt_mapping: List[int]
    lora_requests: Set[LoRARequest]
    slot_mapping: List[int]

    @classmethod
    def empty(cls):
        """Return an instance describing "no decode requests".

        Every list/set field is a fresh empty container and
        ``attn_metadata`` is ``None``. The instance is built through ``cls``
        so that a subclass gets an instance of its own type.
        """
        return cls(
            input_tokens=[],
            input_positions=[],
            attn_metadata=None,
            lora_index_mapping=[],
            lora_prompt_mapping=[],
            lora_requests=set(),
            slot_mapping=[],
        )
# How batches are constructed.
class BatchType(IntEnum):
    """Describes which kinds of requests make up a batch."""

    PREFILL = 0  # Every batch is prefill.
    DECODE = 1  # Every batch is decode.
    MIXED = 2  # Batch is a mixture of prefill and decode.
class
ModelRunner
:
class
ModelRunner
:
def
__init__
(
def
__init__
(
...
@@ -47,6 +107,7 @@ class ModelRunner:
...
@@ -47,6 +107,7 @@ class ModelRunner:
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
...
@@ -56,6 +117,7 @@ class ModelRunner:
...
@@ -56,6 +117,7 @@ class ModelRunner:
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
# model_config can be None in tests/samplers/test_sampler.py.
...
@@ -66,23 +128,17 @@ class ModelRunner:
...
@@ -66,23 +128,17 @@ class ModelRunner:
if
device_config
is
not
None
else
DeviceConfig
())
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
# Set after load_model.
self
.
block_size
=
None
# Set after initial profiling.
self
.
lora_manager
:
LRUCacheWorkerLoRAManager
=
None
self
.
lora_manager
=
None
self
.
graph_runners
:
Dict
[
int
,
CUDAGraphRunner
]
=
{}
self
.
graph_runners
:
Dict
[
int
,
CUDAGraphRunner
]
=
{}
self
.
graph_memory_pool
=
None
# Set during graph capture.
self
.
graph_memory_pool
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Set during graph capture.
self
.
max_context_len_to_capture
=
(
self
.
max_context_len_to_capture
=
(
self
.
model_config
.
max_context_len_to_capture
self
.
model_config
.
max_context_len_to_capture
if
self
.
model_config
is
not
None
else
0
)
if
self
.
model_config
is
not
None
else
0
)
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
=
None
# Set after initial profiling.
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vision_language_config
=
vision_language_config
self
.
vision_language_config
=
vision_language_config
...
@@ -90,15 +146,28 @@ class ModelRunner:
...
@@ -90,15 +146,28 @@ class ModelRunner:
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
# Lazy initialization
self
.
model
:
torch
.
nn
.
Module
# Set after load_model
self
.
block_size
:
int
# Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
:
torch
.
Tensor
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
self
.
model
=
get_model
(
self
.
model_config
,
model_config
=
self
.
model_config
,
self
.
device_config
,
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
vision_language_config
=
self
.
vision_language_config
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
scheduler_config
=
self
.
scheduler_config
,
)
self
.
model_memory_usage
=
m
.
consumed_memory
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
f
"Loading model weights took "
logger
.
info
(
f
"Loading model weights took "
...
@@ -120,6 +189,26 @@ class ModelRunner:
...
@@ -120,6 +189,26 @@ class ModelRunner:
self
.
model
.
embedding_padding_modules
)
self
.
model
.
embedding_padding_modules
)
self
.
model
=
self
.
lora_manager
.
create_lora_manager
(
self
.
model
)
self
.
model
=
self
.
lora_manager
.
create_lora_manager
(
self
.
model
)
if
self
.
kv_cache_dtype
==
"fp8"
and
is_hip
():
# Currently scaled KV cache is only enabled on ROCm
if
self
.
model_config
.
quantization_param_path
is
not
None
:
if
callable
(
getattr
(
self
.
model
,
"load_kv_cache_scales"
,
None
)):
self
.
model
.
load_kv_cache_scales
(
self
.
model_config
.
quantization_param_path
)
else
:
raise
RuntimeError
(
"Using FP8 KV cache and scaling "
"factors provided but model "
f
"
{
self
.
model
.
__class__
}
does not "
"support loading scaling factors."
)
else
:
logger
.
warn
(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
elif
self
.
model_config
.
quantization_param_path
is
not
None
:
logger
.
warn
(
"KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used."
)
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
self
.
block_size
=
block_size
self
.
block_size
=
block_size
...
@@ -134,10 +223,7 @@ class ModelRunner:
...
@@ -134,10 +223,7 @@ class ModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
)
->
PreparePromptMetadata
:
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
],
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
...
@@ -151,6 +237,9 @@ class ModelRunner:
...
@@ -151,6 +237,9 @@ class ModelRunner:
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
prefix_block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
multi_modal_input_list
:
List
[
torch
.
Tensor
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
PreparePromptMetadata
.
empty
()
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
...
@@ -160,7 +249,8 @@ class ModelRunner:
...
@@ -160,7 +249,8 @@ class ModelRunner:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
computed_block_nums
is
not
None
):
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"chunked prefill cannot be used with prefix caching "
"now."
)
"now."
)
...
@@ -172,13 +262,8 @@ class ModelRunner:
...
@@ -172,13 +262,8 @@ class ModelRunner:
# it contains output tokens.
# it contains output tokens.
prefill_end
=
min
(
seq_data
.
get_len
(),
prefill_end
=
min
(
seq_data
.
get_len
(),
computed_len
+
token_chunk_size
)
computed_len
+
token_chunk_size
)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_tokens
=
seq_data
.
get_token_ids
()[
computed_len
:
prefill_end
]
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
prefill_end
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert
prefill_end
==
seq_data
.
get_len
()
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
...
@@ -188,6 +273,14 @@ class ModelRunner:
...
@@ -188,6 +273,14 @@ class ModelRunner:
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
computed_len
=
len
(
computed_block_nums
)
*
self
.
block_size
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prompt_tokens
=
prompt_tokens
[
computed_len
:]
prefix_block_tables
.
append
(
computed_block_nums
)
prefix_block_tables
.
append
(
computed_block_nums
)
elif
self
.
scheduler_config
.
chunked_prefill_enabled
:
if
seq_group_metadata
.
block_tables
is
not
None
:
# Prefill has chunked before.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
prefix_block_tables
.
append
(
block_table
)
else
:
# The first prefill.
prefix_block_tables
.
append
([])
else
:
else
:
prefix_block_tables
.
append
([])
prefix_block_tables
.
append
([])
# Right now, prefill start is always 0. However, this
# Right now, prefill start is always 0. However, this
...
@@ -202,7 +295,6 @@ class ModelRunner:
...
@@ -202,7 +295,6 @@ class ModelRunner:
# NOTE(woosuk): Here we assume that the first token in the prompt
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
prefill_end
)))
input_positions
.
extend
(
list
(
range
(
computed_len
,
prefill_end
)))
lora_id
=
seq_group_metadata
.
lora_int_id
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
if
lora_id
>
0
:
...
@@ -250,20 +342,8 @@ class ModelRunner:
...
@@ -250,20 +342,8 @@ class ModelRunner:
max_subquery_len
=
max
(
subquery_lens
)
max_subquery_len
=
max
(
subquery_lens
)
max_prompt_len
=
max
(
prompt_lens
)
max_prompt_len
=
max
(
prompt_lens
)
num_prompt_tokens
=
len
(
input_tokens
)
assert
max_subquery_len
>
0
assert
max_subquery_len
>
0
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
lora_index_mapping
=
lora_index_mapping
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -315,11 +395,8 @@ class ModelRunner:
...
@@ -315,11 +395,8 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
prompt_lens
,
prompt_lens
=
prompt_lens
,
prompt_lens_tensor
=
prompt_lens_tensor
,
prompt_lens_tensor
=
prompt_lens_tensor
,
num_prompt_tokens
=
num_prompt_tokens
,
num_generation_tokens
=
0
,
max_subquery_len
=
max_subquery_len
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_context_len
=
None
,
max_prompt_len
=
max_prompt_len
,
max_prompt_len
=
max_prompt_len
,
...
@@ -328,18 +405,25 @@ class ModelRunner:
...
@@ -328,18 +405,25 @@ class ModelRunner:
context_lens
=
context_lens_tensor
,
context_lens
=
context_lens_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
return
PreparePromptMetadata
(
lora_requests
,
multi_modal_input
)
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
prompt_lens
=
prompt_lens
,
subquery_lens
=
subquery_lens
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
multi_modal_input
=
multi_modal_input
,
slot_mapping
=
slot_mapping
,
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
)
->
PrepareDecodeMetadata
:
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
...
@@ -349,6 +433,9 @@ class ModelRunner:
...
@@ -349,6 +433,9 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
lora_requests
:
Set
[
LoRARequest
]
=
set
()
if
len
(
seq_group_metadata_list
)
==
0
:
return
PrepareDecodeMetadata
.
empty
()
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
assert
seq_group_metadata
.
token_chunk_size
==
1
...
@@ -407,25 +494,16 @@ class ModelRunner:
...
@@ -407,25 +494,16 @@ class ModelRunner:
lora_index_mapping
.
append
(
0
)
lora_index_mapping
.
append
(
0
)
batch_size
=
graph_batch_size
batch_size
=
graph_batch_size
input_tokens
=
torch
.
tensor
(
input_tokens
,
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
if
use_captured_graph
:
if
use_captured_graph
:
# When using cuda-graph all these tensors should be
# When using cuda-graph all these tensors should be
# padded.
# padded.
assert
context_lens
.
shape
[
0
]
==
input_tokens
.
shape
[
0
]
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
input_tokens
)
assert
context_lens
.
shape
[
0
]
==
input_positions
.
shape
[
0
]
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
input_positions
)
assert
context_lens
.
shape
[
0
]
==
slot_mapping
.
shape
[
0
]
assert
context_lens
_tensor
.
shape
[
0
]
==
len
(
slot_mapping
)
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
...
@@ -447,23 +525,26 @@ class ModelRunner:
...
@@ -447,23 +525,26 @@ class ModelRunner:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
len
(
input_tokens
),
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_context_len
=
max_context_len
,
max_prompt_len
=
None
,
max_prompt_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
,
context_lens
=
context_lens
_tensor
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
PrepareDecodeMetadata
(
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
lora_index_mapping
=
lora_index_mapping
,
lora_prompt_mapping
=
lora_prompt_mapping
,
lora_requests
=
lora_requests
,
slot_mapping
=
slot_mapping
,
)
def
_prepare_sample
(
def
_prepare_sample
(
self
,
self
,
...
@@ -475,7 +556,11 @@ class ModelRunner:
...
@@ -475,7 +556,11 @@ class ModelRunner:
selected_token_indices
:
List
[
int
]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
...
@@ -493,10 +578,9 @@ class ModelRunner:
...
@@ -493,10 +578,9 @@ class ModelRunner:
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices_start_idx
+=
subquery_len
-
1
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
,
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
categorized_sampled_token_indices_start_idx
))
])
categorized_sample_indices_start_idx
+=
1
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
...
@@ -520,15 +604,16 @@ class ModelRunner:
...
@@ -520,15 +604,16 @@ class ModelRunner:
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
sampling_params
.
sampling_type
].
extend
(
zip
(
list
(
range
(
zip
(
categorized_sample_indices_start_idx
,
range
(
categorized_sample_indices_start_idx
+
categorized_sample_indices_start_idx
,
num_seqs
),
categorized_sample_indices_start_idx
+
range
(
num_seqs
),
categorized_sampled_token_indices_start_idx
,
range
(
categorized_sampled_token_indices_start_idx
+
categorized_sampled_token_indices_start_idx
,
num_seqs
)))
categorized_sampled_token_indices_start_idx
+
num_seqs
))))
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sample_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
categorized_sampled_token_indices_start_idx
+=
num_seqs
...
@@ -565,30 +650,70 @@ class ModelRunner:
...
@@ -565,30 +650,70 @@ class ModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
SamplingMetadata
,
Set
[
in
t
],
LoRAMapping
,
torch
.
Tensor
]:
Set
[
LoRAReques
t
],
LoRAMapping
,
torch
.
Tensor
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
prefill_reqs
=
[]
# all decodes.
decode_reqs
=
[]
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
for
seq_group_meta
in
seq_group_metadata_list
:
if
seq_group_meta
.
is_prompt
:
prefill_reqs
.
append
(
seq_group_meta
)
else
:
decode_reqs
.
append
(
seq_group_meta
)
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
(
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
input_tokens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
input_positions
,
lora_requests
,
multi_modal_input
prefill_attn_metadata
,
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
prompt_lens
,
else
:
subquery_lens
,
(
input_tokens
,
input_positions
,
attn_metadata
,
lora_index_mapping
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
lora_requests
,
prompt_lens
=
[]
multi_modal_input
,
subquery_lens
=
None
slot_mapping
,
multi_modal_input
=
None
)
=
self
.
_prepare_prompt
(
prefill_reqs
)
(
decode_input_tokens
,
decode_input_positions
,
decode_attn_metadata
,
decode_lora_index_mapping
,
decode_lora_prompt_mapping
,
decode_lora_requests
,
decode_slot_mapping
,
)
=
self
.
_prepare_decode
(
decode_reqs
)
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
prompt_lens
,
subquery_lens
)
subquery_lens
)
if
not
self
.
scheduler_config
.
chunked_prefill_enabled
:
assert
(
len
(
prefill_reqs
)
and
len
(
decode_reqs
))
==
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
len
(
input_tokens
)
num_decode_tokens
=
len
(
decode_input_tokens
)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens
.
extend
(
decode_input_tokens
)
input_positions
.
extend
(
decode_input_positions
)
slot_mapping
.
extend
(
decode_slot_mapping
)
lora_index_mapping
.
extend
(
decode_lora_index_mapping
)
lora_prompt_mapping
.
extend
(
decode_lora_prompt_mapping
)
lora_requests
.
update
(
decode_lora_requests
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
self
.
lora_config
:
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_index_mapping
,
...
@@ -598,6 +723,16 @@ class ModelRunner:
...
@@ -598,6 +723,16 @@ class ModelRunner:
lora_mapping
=
None
lora_mapping
=
None
# Broadcast the metadata.
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if
(
prefill_attn_metadata
is
not
None
and
decode_attn_metadata
is
not
None
):
batch_type
=
BatchType
.
MIXED
elif
prefill_attn_metadata
is
not
None
:
batch_type
=
BatchType
.
PREFILL
else
:
batch_type
=
BatchType
.
DECODE
metadata_dict
=
{
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"input_positions"
:
input_positions
,
...
@@ -606,19 +741,50 @@ class ModelRunner:
...
@@ -606,19 +741,50 @@ class ModelRunner:
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
"multi_modal_input"
:
multi_modal_input
,
"multi_modal_input"
:
multi_modal_input
,
"num_prefill_tokens"
:
num_prefill_tokens
,
"num_decode_tokens"
:
num_decode_tokens
,
"slot_mapping"
:
slot_mapping
,
"num_prefills"
:
num_prefills
,
"batch_type"
:
batch_type
,
}
}
metadata_dict
.
update
(
attn_metadata
.
asdict_zerocopy
())
if
prefill_attn_metadata
is
not
None
:
metadata_dict
.
update
(
prefill_attn_metadata
.
asdict_zerocopy
())
else
:
assert
decode_attn_metadata
is
not
None
metadata_dict
.
update
(
decode_attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if
batch_type
==
BatchType
.
MIXED
:
assert
decode_attn_metadata
is
not
None
metadata_dict
=
decode_attn_metadata
.
asdict_zerocopy
()
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
slot_mapping
=
metadata_dict
.
pop
(
"slot_mapping"
)
num_prefills
=
metadata_dict
.
pop
(
"num_prefills"
)
selected_token_indices
=
metadata_dict
.
pop
(
selected_token_indices
=
metadata_dict
.
pop
(
"selected_token_indices"
)
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
multi_modal_input
=
metadata_dict
.
pop
(
"multi_modal_input"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
num_prefill_tokens
=
metadata_dict
.
pop
(
"num_prefill_tokens"
)
num_decode_tokens
=
metadata_dict
.
pop
(
"num_decode_tokens"
)
batch_type
=
metadata_dict
.
pop
(
"batch_type"
)
# Create an attention metadata.
prefill_attn_metadata
=
None
decode_attn_metadata
=
None
if
batch_type
==
BatchType
.
PREFILL
or
batch_type
==
BatchType
.
MIXED
:
prefill_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
else
:
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
seq_data
=
None
,
...
@@ -629,6 +795,23 @@ class ModelRunner:
...
@@ -629,6 +795,23 @@ class ModelRunner:
perform_sampling
=
False
,
perform_sampling
=
False
,
)
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if
batch_type
==
BatchType
.
MIXED
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
decode_attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
prefill_metadata
=
prefill_attn_metadata
,
decode_metadata
=
decode_attn_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
return
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
sampling_metadata
,
lora_requests
,
lora_mapping
,
multi_modal_input
)
multi_modal_input
)
...
@@ -636,7 +819,7 @@ class ModelRunner:
...
@@ -636,7 +819,7 @@ class ModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
(
input_tokens
,
input_positions
,
attn_metadata
,
sampling_metadata
,
...
@@ -646,8 +829,10 @@ class ModelRunner:
...
@@ -646,8 +829,10 @@ class ModelRunner:
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
# Execute the model.
# Currently cuda graph is only supported by the decode phase.
if
attn_metadata
.
use_cuda_graph
:
prefill_meta
=
attn_metadata
.
prefill_metadata
decode_meta
=
attn_metadata
.
decode_metadata
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
graph_batch_size
=
input_tokens
.
shape
[
0
]
graph_batch_size
=
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
else
:
...
@@ -748,7 +933,7 @@ class ModelRunner:
...
@@ -748,7 +933,7 @@ class ModelRunner:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_all_loras
()
return
self
.
lora_manager
.
remove_all_loras
()
def
set_active_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
lora_mapping
:
LoRAMapping
)
->
None
:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
...
@@ -825,13 +1010,10 @@ class ModelRunner:
...
@@ -825,13 +1010,10 @@ class ModelRunner:
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
# Create dummy attn_metadata.
# Create dummy attn_metadata.
attn
_metadata
=
self
.
attn_backend
.
make_metadata
(
decode
_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
[:
batch_size
],
prompt_lens
=
None
,
prompt_lens
=
None
,
prompt_lens_tensor
=
None
,
prompt_lens_tensor
=
None
,
num_prompt_tokens
=
0
,
num_generation_tokens
=
batch_size
,
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_prompt_len
=
None
,
max_prompt_len
=
None
,
...
@@ -840,6 +1022,14 @@ class ModelRunner:
...
@@ -840,6 +1022,14 @@ class ModelRunner:
context_lens
=
context_lens
[:
batch_size
],
context_lens
=
context_lens
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
block_tables
=
block_tables
[:
batch_size
],
use_cuda_graph
=
True
,
use_cuda_graph
=
True
,
)
attn_metadata
=
AttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
slot_mapping
[:
batch_size
],
prefill_metadata
=
None
,
decode_metadata
=
decode_metadata
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
...
@@ -885,10 +1075,16 @@ class CUDAGraphRunner:
...
@@ -885,10 +1075,16 @@ class CUDAGraphRunner:
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
self
.
model
=
model
self
.
model
=
model
self
.
graph
=
None
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
input_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
output_buffers
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_graph
:
Optional
[
torch
.
cuda
.
CUDAGraph
]
=
None
@
property
def
graph
(
self
):
assert
self
.
_graph
is
not
None
return
self
.
_graph
def
capture
(
def
capture
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -898,7 +1094,7 @@ class CUDAGraphRunner:
...
@@ -898,7 +1094,7 @@ class CUDAGraphRunner:
memory_pool
,
memory_pool
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
assert
self
.
graph
is
None
assert
self
.
_
graph
is
None
# Run the model once without capturing the graph.
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
...
@@ -915,8 +1111,8 @@ class CUDAGraphRunner:
...
@@ -915,8 +1111,8 @@ class CUDAGraphRunner:
# Capture the graph.
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
torch
.
cuda
.
graph
(
self
.
_
graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
_maybe_pynccl
():
with
_maybe_pynccl
():
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
...
@@ -933,8 +1129,8 @@ class CUDAGraphRunner:
...
@@ -933,8 +1129,8 @@ class CUDAGraphRunner:
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"context_lens"
:
attn_metadata
.
context_lens
,
"context_lens"
:
attn_metadata
.
decode_metadata
.
context_lens
,
"block_tables"
:
attn_metadata
.
block_tables
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
return
return
...
@@ -955,10 +1151,10 @@ class CUDAGraphRunner:
...
@@ -955,10 +1151,10 @@ class CUDAGraphRunner:
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn_metadata
.
slot_mapping
,
non_blocking
=
True
)
non_blocking
=
True
)
self
.
input_buffers
[
"context_lens"
].
copy_
(
attn_metadata
.
context_lens
,
self
.
input_buffers
[
"context_lens"
].
copy_
(
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
context_lens
,
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
block_tables
,
self
.
input_buffers
[
"block_tables"
].
copy_
(
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
# Run the graph.
# Run the graph.
self
.
graph
.
replay
()
self
.
graph
.
replay
()
...
...
vllm/worker/neuron_model_runner.py
View file @
99b471c2
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.
neuron_
model_loader
import
get_neuron_model
from
vllm.model_executor.model_loader
.neuron
import
get_neuron_model
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
...
@@ -34,9 +35,11 @@ class NeuronModelRunner:
...
@@ -34,9 +35,11 @@ class NeuronModelRunner:
self
.
device_config
=
(
device_config
self
.
device_config
=
(
device_config
if
device_config
is
not
None
else
DeviceConfig
())
if
device_config
is
not
None
else
DeviceConfig
())
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
model
=
None
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
# Lazy initialization.
self
.
model
:
nn
.
Module
# initialize after load_model.
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
self
.
model
=
get_neuron_model
(
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
parallel_config
=
self
.
parallel_config
,
...
@@ -147,7 +150,11 @@ class NeuronModelRunner:
...
@@ -147,7 +150,11 @@ class NeuronModelRunner:
selected_token_indices
:
List
[
int
]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
generators
:
List
[
torch
.
Generator
]
=
[]
selected_token_start_idx
=
0
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices
:
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]]
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
categorized_sample_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
categorized_sampled_token_indices_start_idx
=
0
...
@@ -165,10 +172,9 @@ class NeuronModelRunner:
...
@@ -165,10 +172,9 @@ class NeuronModelRunner:
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices
[
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
([
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
,
(
categorized_sample_indices_start_idx
,
categorized_sampled_token_indices_start_idx
categorized_sampled_token_indices_start_idx
))
])
categorized_sample_indices_start_idx
+=
1
categorized_sample_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
categorized_sampled_token_indices_start_idx
+=
1
...
@@ -237,7 +243,7 @@ class NeuronModelRunner:
...
@@ -237,7 +243,7 @@ class NeuronModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
SamplingMetadata
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
SamplingMetadata
]:
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
...
@@ -259,7 +265,7 @@ class NeuronModelRunner:
...
@@ -259,7 +265,7 @@ class NeuronModelRunner:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]
]
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
(
input_tokens
,
input_positions
,
input_block_ids
,
sampling_metadata
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
...
...
vllm/worker/neuron_worker.py
View file @
99b471c2
"""A Neuron worker class."""
"""A Neuron worker class."""
from
typing
import
List
,
Optional
from
typing
import
List
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NeuronWorker
:
class
NeuronWorker
(
LoraNotSupportedWorkerBase
)
:
"""A worker class that executes the model on a group of neuron cores.
"""A worker class that executes the model on a group of neuron cores.
"""
"""
...
@@ -21,11 +22,17 @@ class NeuronWorker:
...
@@ -21,11 +22,17 @@ class NeuronWorker:
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
self
.
model_runner
=
NeuronModelRunner
(
model_config
,
parallel_config
,
self
.
model_runner
=
NeuronModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
)
scheduler_config
,
device_config
)
...
@@ -37,16 +44,55 @@ class NeuronWorker:
...
@@ -37,16 +44,55 @@ class NeuronWorker:
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
We configure num_gpu_blocks to be equal to max_num_seqs.
"""
# Set the number of GPU blocks to be the same as the maximum number of
# sequences that can be processed in a single batch. This is equivalent
# to schedule without PagedAttention.
num_gpu_blocks
=
self
.
scheduler_config
.
max_num_seqs
# Swap not yet supported with Neuron backend.
num_cpu_blocks
=
0
return
num_gpu_blocks
,
num_cpu_blocks
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""Initialize the KV cache.
"""
# Different values are not tested.
assert
num_cpu_blocks
==
0
assert
num_gpu_blocks
==
self
.
scheduler_config
.
max_num_seqs
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Optional
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
)
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
)
return
output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return
[
output
]
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Determine the size in bytes of a cache block.
This is required for speculative decoding; it is not yet implemented.
"""
raise
NotImplementedError
vllm/worker/worker.py
View file @
99b471c2
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.distributed.device_communicators
import
pynccl_utils
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
init_custom_ar
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor.parallel_utils
import
pynccl_utils
from
vllm.model_executor.parallel_utils.communication_op
import
(
broadcast_tensor_dict
)
from
vllm.model_executor.parallel_utils.custom_all_reduce
import
init_custom_ar
from
vllm.model_executor.parallel_utils.parallel_state
import
(
ensure_model_parallel_initialized
)
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker_base
import
WorkerBase
class
Worker
:
class
Worker
(
WorkerBase
)
:
"""A worker class that executes (a partition of) the model on a GPU.
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
Each worker is associated with a single GPU. The worker is responsible for
...
@@ -35,26 +37,33 @@ class Worker:
...
@@ -35,26 +37,33 @@ class Worker:
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
local_rank
:
int
,
local_rank
:
int
,
rank
:
int
,
rank
:
int
,
distributed_init_method
:
str
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
local_rank
=
local_rank
self
.
local_rank
=
local_rank
self
.
rank
=
rank
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
distributed_init_method
=
distributed_init_method
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
assert
self
.
rank
==
0
,
"The driver worker must have rank 0."
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
init_cached_hf_modules
()
self
.
vision_language_config
=
vision_language_config
self
.
vision_language_config
=
vision_language_config
if
self
.
vision_language_config
:
if
self
.
vision_language_config
:
assert
not
self
.
lora_config
,
(
assert
not
self
.
lora_config
,
(
...
@@ -65,15 +74,16 @@ class Worker:
...
@@ -65,15 +74,16 @@ class Worker:
parallel_config
,
parallel_config
,
scheduler_config
,
scheduler_config
,
device_config
,
device_config
,
load_config
=
load_config
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
kv_
cache_dtype
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
is_driver_worker
,
is_driver_worker
=
is_driver_worker
,
vision_language_config
=
vision_language_config
)
vision_language_config
=
vision_language_config
,
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
# initialize_cache.
self
.
cache_config
=
None
self
.
cache_engine
:
CacheEngine
self
.
cache_engine
=
None
self
.
gpu_cache
:
List
[
torch
.
Tensor
]
self
.
gpu_cache
=
None
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
@@ -97,9 +107,9 @@ class Worker:
...
@@ -97,9 +107,9 @@ class Worker:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
f
"Not support device type:
{
self
.
device_config
.
device
}
"
)
# Initialize the distributed environment.
# Initialize the distributed environment.
init_distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
init_
worker_
distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
self
.
distributed_init_method
,
self
.
distributed_init_method
,
self
.
local_rank
)
self
.
local_rank
)
# Set random seed.
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
...
@@ -107,20 +117,17 @@ class Worker:
...
@@ -107,20 +117,17 @@ class Worker:
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_num_available_blocks
(
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
self
,
"""Profiles the peak memory usage of the model to determine how many
block_size
:
int
,
KV blocks may be allocated without OOMs.
gpu_memory_utilization
:
float
,
cpu_swap_space
:
int
,
The engine will first conduct a profiling of the existing memory usage.
cache_dtype
:
str
,
Then, it calculate the maximum possible number of GPU and CPU blocks
)
->
Tuple
[
int
,
int
]:
that can be allocated with the remaining free memory.
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
.. tip::
You may limit the usage of GPU memory
Args:
by adjusting the `gpu_memory_utilization` parameter.
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
"""
"""
# Profile the memory usage of the model and get the maximum number of
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
# cache blocks that can be allocated with the remaining free memory.
...
@@ -141,12 +148,12 @@ class Worker:
...
@@ -141,12 +148,12 @@ class Worker:
"Error in memory profiling. This happens when the GPU memory was "
"Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
"not properly cleaned up before initializing the vLLM instance."
)
cache_block_size
=
self
.
get_cache_block_size_bytes
(
cache_block_size
=
self
.
get_cache_block_size_bytes
()
block_size
,
cache_dtype
)
num_gpu_blocks
=
int
(
num_gpu_blocks
=
int
(
(
total_gpu_memory
*
gpu_memory_utilization
-
peak_memory
)
//
(
total_gpu_memory
*
self
.
cache_config
.
gpu_memory_utilization
-
cache_block_size
)
peak_memory
)
//
cache_block_size
)
num_cpu_blocks
=
int
(
cpu_swap_space
//
cache_block_size
)
num_cpu_blocks
=
int
(
self
.
cache_config
.
swap_space_bytes
//
cache_block_size
)
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
num_gpu_blocks
=
max
(
num_gpu_blocks
,
0
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
num_cpu_blocks
=
max
(
num_cpu_blocks
,
0
)
if
self
.
model_runner
.
lora_manager
:
if
self
.
model_runner
.
lora_manager
:
...
@@ -155,14 +162,30 @@ class Worker:
...
@@ -155,14 +162,30 @@ class Worker:
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
def
init_cache_engine
(
self
,
cache_config
:
CacheConfig
)
->
None
:
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
self
.
cache_config
=
cache_config
num_cpu_blocks
:
int
)
->
None
:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
"""
raise_if_cache_size_invalid
(
num_gpu_blocks
,
self
.
cache_config
.
block_size
,
self
.
model_config
.
max_model_len
)
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
_init_cache_engine
()
self
.
_warm_up_model
()
def
_init_cache_engine
(
self
):
assert
self
.
cache_config
.
num_gpu_blocks
is
not
None
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
parallel_config
)
self
.
gpu_cache
=
self
.
cache_engine
.
gpu_cache
self
.
gpu_cache
=
self
.
cache_engine
.
gpu_cache
self
.
model_runner
.
set_block_size
(
self
.
cache_engine
.
block_size
)
self
.
model_runner
.
set_block_size
(
self
.
cache_engine
.
block_size
)
def
warm_up_model
(
self
)
->
None
:
def
_
warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
(
self
.
gpu_cache
)
self
.
model_runner
.
capture_model
(
self
.
gpu_cache
)
# Reset the seed to ensure that the random state is not affected by
# Reset the seed to ensure that the random state is not affected by
...
@@ -191,14 +214,16 @@ class Worker:
...
@@ -191,14 +214,16 @@ class Worker:
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
num_lookahead_slots
:
int
=
0
,
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
assert
blocks_to_copy
is
not
None
data
=
{
data
:
Dict
[
str
,
Any
]
=
{
"num_seq_groups"
:
num_seq_groups
,
"num_seq_groups"
:
num_seq_groups
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
...
@@ -212,15 +237,21 @@ class Worker:
...
@@ -212,15 +237,21 @@ class Worker:
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
blocks_to_swap_out
=
data
[
"blocks_to_swap_out"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
blocks_to_copy
=
data
[
"blocks_to_copy"
]
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
assert
blocks_to_copy
is
not
None
self
.
cache_swap
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
self
.
cache_swap
(
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
gpu_cache
)
self
.
gpu_cache
)
return
output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return
[
output
]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
@@ -239,40 +270,23 @@ class Worker:
...
@@ -239,40 +270,23 @@ class Worker:
def
vocab_size
(
self
)
->
int
:
def
vocab_size
(
self
)
->
int
:
return
self
.
model_runner
.
vocab_size
return
self
.
model_runner
.
vocab_size
def
get_cache_block_size_bytes
(
self
,
block_size
:
int
,
def
get_cache_block_size_bytes
(
self
)
->
int
:
cache_dtype
:
str
)
->
int
:
"""Get the size of the KV cache block size in bytes.
"""Get the size of the KV cache block size in bytes.
"""
"""
return
CacheEngine
.
get_cache_block_size
(
block_size
,
cache_dtype
,
return
CacheEngine
.
get_cache_block_size
(
self
.
cache_config
,
self
.
model_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
parallel_config
)
def
init_distributed_environment
(
def
init_
worker_
distributed_environment
(
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
rank
:
int
,
rank
:
int
,
distributed_init_method
:
Optional
[
str
]
=
None
,
distributed_init_method
:
Optional
[
str
]
=
None
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
"""Initialize the distributed environment."""
"""Initialize the distributed environment."""
if
torch
.
distributed
.
is_initialized
():
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
torch_world_size
=
torch
.
distributed
.
get_world_size
()
distributed_init_method
,
local_rank
)
if
torch_world_size
!=
parallel_config
.
world_size
:
raise
RuntimeError
(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f
"(
{
torch_world_size
}
vs.
{
parallel_config
.
world_size
}
)."
)
elif
not
distributed_init_method
:
raise
ValueError
(
"distributed_init_method must be set if torch.distributed "
"is not already initialized"
)
else
:
torch
.
distributed
.
init_process_group
(
backend
=
"nccl"
,
world_size
=
parallel_config
.
world_size
,
rank
=
rank
,
init_method
=
distributed_init_method
,
)
if
pynccl_utils
.
is_initialized
():
if
pynccl_utils
.
is_initialized
():
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
pynccl_world_size
=
pynccl_utils
.
get_world_size
()
...
@@ -284,17 +298,10 @@ def init_distributed_environment(
...
@@ -284,17 +298,10 @@ def init_distributed_environment(
elif
parallel_config
.
world_size
>
1
:
elif
parallel_config
.
world_size
>
1
:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# is 1.
pynccl_utils
.
init_process_group
(
# NOTE(kaichao): By default, pynccl will use information inside
world_size
=
parallel_config
.
world_size
,
# `parallel_state` for initialization.
local_rank
=
local_rank
,
pynccl_utils
.
init_process_group
()
rank
=
rank
,
init_method
=
distributed_init_method
,
)
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
if
pynccl_utils
.
is_initialized
():
pynccl_utils
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
...
@@ -302,6 +309,11 @@ def init_distributed_environment(
...
@@ -302,6 +309,11 @@ def init_distributed_environment(
if
not
parallel_config
.
disable_custom_all_reduce
:
if
not
parallel_config
.
disable_custom_all_reduce
:
init_custom_ar
()
init_custom_ar
()
# A small all_reduce for warmup.
torch
.
distributed
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
if
pynccl_utils
.
is_initialized
():
pynccl_utils
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
...
@@ -315,3 +327,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
...
@@ -315,3 +327,19 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
. "
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
. "
"You can use float16 instead by explicitly setting the"
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half."
)
"`dtype` flag in CLI, for example: --dtype=half."
)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    """Validate that the KV cache can hold at least one full-length sequence.

    Args:
        num_gpu_blocks: Number of cache blocks available on the GPU.
        block_size: Number of token slots per cache block.
        max_model_len: Maximum sequence length of the model, in tokens.

    Raises:
        ValueError: If there are no GPU blocks at all, or if
            ``block_size * num_gpu_blocks`` is smaller than ``max_model_len``.
    """
    # Guard 1: the cache must contain at least one block.
    if not num_gpu_blocks > 0:
        no_memory_msg = ("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
        raise ValueError(no_memory_msg)

    # Guard 2: total token capacity of the cache must cover the longest
    # sequence the model accepts.
    kv_token_capacity = block_size * num_gpu_blocks
    if kv_token_capacity < max_model_len:
        too_small_msg = (
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({kv_token_capacity}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
        raise ValueError(too_small_msg)
vllm/worker/worker_base.py
0 → 100644
View file @
99b471c2
import
datetime
import
importlib
import
os
import
tempfile
import
threading
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
vllm.logger
import
enable_trace_function_call
,
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.utils
import
get_vllm_instance_id
,
update_environment_variables
# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's name.
logger = init_logger(__name__)
class WorkerBase(ABC):
    """Abstract worker contract.

    Defines the set of operations every hardware-specific worker must
    provide, so that vLLM can keep the per-hardware implementations cleanly
    separated from the rest of the engine.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Set up device state.

        Examples of such state are loading the model or performing other
        on-device memory allocations.
        """
        raise NotImplementedError()

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many KV-cache blocks can be allocated.

        Implementations are free to profile or to apply other heuristics
        when sizing the caches.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple.
            ``num_gpu_blocks`` counts blocks that are "active" on the device
            and can be appended to; ``num_cpu_blocks`` counts "swapped"
            blocks living in CPU memory, which cannot be appended to.
        """
        raise NotImplementedError()

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Create the KV cache using the supplied sizes, given in blocks."""
        raise NotImplementedError()

    @abstractmethod
    def execute_model(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]],
    ) -> List[SamplerOutput]:
        """Run one or more model steps over the given sequences.

        When no sequences are supplied, no model step has to be executed.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Report how many bytes one cache block occupies.

        This value is used in speculative decoding.
        """
        raise NotImplementedError()

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """LoRA hook taking a ``LoRARequest``.

        The meaning of the boolean result is left to implementers; it is not
        specified by this interface.
        """
        raise NotImplementedError()

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """LoRA hook taking an integer LoRA id.

        The meaning of the boolean result is left to implementers; it is not
        specified by this interface.
        """
        raise NotImplementedError()

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """LoRA hook returning a set of integers (presumably LoRA ids)."""
        raise NotImplementedError()
class LoraNotSupportedWorkerBase(WorkerBase):
    """WorkerBase specialization for workers that have no LoRA support.

    It supplies the three LoRA methods of the interface; each of them
    unconditionally raises ``ValueError`` naming the concrete worker type.
    """

    def _lora_unsupported(self) -> ValueError:
        """Build the error shared by all LoRA entry points of this class."""
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Always raises ``ValueError``: LoRA is not supported."""
        raise self._lora_unsupported()

    def remove_lora(self, lora_id: int) -> bool:
        """Always raises ``ValueError``: LoRA is not supported."""
        raise self._lora_unsupported()

    def list_loras(self) -> Set[int]:
        """Always raises ``ValueError``: LoRA is not supported."""
        raise self._lora_unsupported()
class WorkerWrapperBase:
    """Lazily-initialized holder for a worker object.

    The wrapper is instantiated first and only remembers the worker's module
    and class name. Environment variables can then be adjusted through
    `update_environment_variables`, and the real worker is constructed later
    by `init_worker`.
    """

    def __init__(self,
                 worker_module_name=None,
                 worker_class_name=None,
                 trust_remote_code: bool = False) -> None:
        """Remember which worker class to build later.

        Args:
            worker_module_name: Importable module path containing the worker
                class.
            worker_class_name: Name of the worker class inside that module.
            trust_remote_code: When true, `init_cached_hf_modules` from
                `vllm.utils` is called right away.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Stays None until `init_worker` has run.
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply `envs` to the process environment.

        Delegates to the module-level `update_environment_variables` helper
        after clearing an existing ``CUDA_VISIBLE_DEVICES`` entry that is
        about to be overwritten.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """Construct the real worker, enabling function tracing if requested.

        All positional and keyword arguments are forwarded to the worker
        class constructor.

        Raises:
            ValueError: If the worker module name or class name was never
                provided to the wrapper.
        """
        # Fail with a clear message instead of the obscure AttributeError
        # that importlib raises when handed None.
        if self.worker_module_name is None or self.worker_class_name is None:
            raise ValueError(
                "worker_module_name and worker_class_name must be set "
                "before calling init_worker")

        if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
            tmp_dir = tempfile.gettempdir()
            # One log file per process/thread/timestamp; spaces coming from
            # the timestamp are replaced so the name has no whitespace.
            filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
                        f"_thread_{threading.get_ident()}_"
                        f"at_{datetime.datetime.now()}.log").replace(" ", "_")
            log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
                                    filename)
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            enable_trace_function_call(log_path)

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call `method` by name and return its result.

        The method is looked up on the wrapped worker once it exists, and on
        the wrapper itself before that (so `init_worker` can be dispatched
        through this entry point).

        Any exception is logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise re-raises the active exception without adding this
            # frame to the traceback a second time.
            raise
Prev
1
…
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment