Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1001 additions
and
730 deletions
+1001
-730
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+4
-4
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-4
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+4
-0
vllm/transformers_utils/configs/chameleon.py
vllm/transformers_utils/configs/chameleon.py
+138
-0
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+8
-0
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+5
-3
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+9
-5
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+7
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+3
-1
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+6
-0
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+5
-4
vllm/utils.py
vllm/utils.py
+66
-12
vllm/version.py
vllm/version.py
+1
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+437
-476
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+17
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+4
-4
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+208
-140
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+73
-71
vllm/worker/worker.py
vllm/worker/worker.py
+1
-1
No files found.
vllm/spec_decode/util.py
View file @
500b93c8
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
...
...
@@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_logprobs
:
List
[
float
],
topk_token_ids
:
List
[
Optional
[
int
]
]
,
topk_logprobs
:
List
[
Optional
[
float
]
]
,
)
->
CompletionSequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
...
...
@@ -68,7 +68,7 @@ def create_sequence_group_output(
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs
:
Dict
[
int
,
Logprob
]
=
{
logprobs
:
Dict
[
Optional
[
int
]
,
Logprob
]
=
{
token_id
:
Logprob
(
logprob
=
token_id_logprob
,
rank
=
token_id_logprob_rank
,
...
...
vllm/transformers_utils/config.py
View file @
500b93c8
...
...
@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
Cha
tGLM
Config
,
Dbrx
Config
,
JAIS
Config
,
Medusa
Config
,
MLPSpeculator
Config
,
MPT
Config
,
RWConfig
)
from
vllm.transformers_utils.configs
import
(
Cha
meleon
Config
,
ChatGLM
Config
,
Dbrx
Config
,
JAIS
Config
,
M
edusaConfig
,
M
LPSpeculatorConfig
,
MPTConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -18,6 +18,7 @@ else:
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
"chameleon"
:
ChameleonConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
500b93c8
from
vllm.transformers_utils.configs.chameleon
import
(
ChameleonConfig
,
ChameleonVQVAEConfig
)
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...
...
@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
__all__
=
[
"ChameleonConfig"
,
"ChameleonVQVAEConfig"
,
"ChatGLMConfig"
,
"DbrxConfig"
,
"MPTConfig"
,
...
...
vllm/transformers_utils/configs/chameleon.py
0 → 100644
View file @
500b93c8
from
typing
import
List
,
Optional
from
transformers
import
PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
    """Configuration container for the Chameleon model.

    The constructor stores its arguments as attributes, validates
    ``rope_scaling``, wraps ``vq_config`` in a ``ChameleonVQVAEConfig`` and
    forwards the special token ids, ``tie_word_embeddings`` and any extra
    keyword arguments to ``PretrainedConfig.__init__``.
    """

    model_type = "chameleon"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        model_parallel_size=1,
        swin_norm=False,
        vq_config=None,
        vocabulary_map=None,
        mlp_bias=False,
        **kwargs,
    ):
        """Store the model hyper-parameters on the config instance.

        ``vq_config`` may be ``None`` (treated as an empty dict) or a mapping
        of keyword arguments for ``ChameleonVQVAEConfig``.

        Raises:
            ValueError: if ``rope_scaling`` is not ``None`` and fails
                ``_rope_scaling_validation``.
        """
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_bias = mlp_bias

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate right after the attribute is set so a bad value fails
        # before the rest of the config is populated.
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.model_parallel_size = model_parallel_size
        self.swin_norm = swin_norm

        if vq_config is None:
            vq_config = {}

        self.vq_config = ChameleonVQVAEConfig(**vq_config)

        self.vocabulary_map = vocabulary_map

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        ``None`` is accepted as-is. Otherwise the value must be a dict with
        exactly two entries: ``type`` (``"linear"`` or ``"dynamic"``) and
        ``factor`` (a ``float`` instance strictly greater than 1.0).

        Raises:
            ValueError: if any of the above constraints is violated.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling,
                          dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, "
                f"`type` and `factor`, got {self.rope_scaling}")
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in [
                "linear", "dynamic"
        ]:
            raise ValueError(
                "`rope_scaling`'s type field must be one of ['linear', "
                f"'dynamic'], got {rope_scaling_type}")
        # Note: the isinstance check means an int factor (e.g. 2) is rejected.
        if rope_scaling_factor is None or not isinstance(
                rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(
                "`rope_scaling`'s factor field must be a float > 1, "
                f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
    """Configuration container for the Chameleon VQ-VAE (image tokenizer).

    Extra keyword arguments are forwarded to ``PretrainedConfig.__init__``;
    every named argument is stored as an attribute of the same name.
    """

    model_type = "chameleon_vqgan"

    def __init__(
            self,
            embed_dim: int = 256,
            num_embeddings: int = 8192,
            double_latent: bool = False,
            latent_channels: int = 256,
            resolution: int = 512,
            in_channels: int = 3,
            base_channels: int = 128,
            channel_multiplier: List[int] = [1, 1, 2, 2, 4],  #noqa
            num_res_blocks: int = 2,
            attn_resolutions: Optional[List[int]] = None,
            dropout: float = 0.0,
            attn_type: str = "vanilla",
            initializer_range=0.02,
            **kwargs,
    ):
        """Store the VQ-VAE hyper-parameters on the config instance."""
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_embeddings = num_embeddings
        self.double_latent = double_latent
        self.latent_channels = latent_channels
        self.resolution = resolution
        self.in_channels = in_channels
        self.base_channels = base_channels
        # The default for `channel_multiplier` is a single list object shared
        # by every call. Store a copy so that mutating one config's list can
        # never leak into the default or into other instances. An explicit
        # None is kept as None, as before.
        if channel_multiplier is not None:
            channel_multiplier = list(channel_multiplier)
        self.channel_multiplier = channel_multiplier
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.attn_type = attn_type
        self.initializer_range = initializer_range
vllm/transformers_utils/detokenizer.py
View file @
500b93c8
...
...
@@ -165,6 +165,12 @@ class Detokenizer:
return
len
(
new_decoded_token_text
)
def _replace_none_with_empty(tokens: List[Optional[str]]):
    """Replace every ``None`` entry of ``tokens`` with ``""``, in place.

    The list object passed in is modified; nothing is returned.
    """
    # Slice assignment rewrites the contents of the caller's list rather than
    # rebinding the local name.
    tokens[:] = ["" if entry is None else entry for entry in tokens]
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
output_tokens
:
List
[
str
],
...
...
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset
=
len
(
new_tokens
)
prefix_offset
=
max
(
read_offset
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty
(
new_tokens
)
return
new_tokens
,
prefix_offset
,
read_offset
...
...
vllm/transformers_utils/tokenizer.py
View file @
500b93c8
...
...
@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode."
)
kwargs
[
"use_fast"
]
=
False
if
"truncation_side"
not
in
kwargs
:
kwargs
[
"truncation_side"
]
=
"left"
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
...
...
@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if
lora_request
is
None
:
return
None
try
:
tokenizer
=
get_tokenizer
(
lora_request
.
lora_local_path
,
*
args
,
**
kwargs
)
tokenizer
=
get_tokenizer
(
lora_request
.
lora_path
,
*
args
,
**
kwargs
)
except
OSError
as
e
:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger
.
warning
(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)"
,
lora_request
.
lora_
local_
path
,
e
)
"(Exception: %s)"
,
lora_request
.
lora_path
,
e
)
tokenizer
=
None
return
tokenizer
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
500b93c8
from
typing
import
Optional
from
typing
import
Optional
,
Type
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
...
...
@@ -16,18 +16,22 @@ else:
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
BaseTokenizerGroup
:
tokenizer_cls
:
Type
[
BaseTokenizerGroup
]
if
tokenizer_pool_config
is
None
:
return
TokenizerGroup
(
**
init_kwargs
)
if
tokenizer_pool_config
.
pool_type
==
"ray"
:
tokenizer_cls
=
TokenizerGroup
elif
isinstance
(
tokenizer_pool_config
.
pool_type
,
type
)
and
issubclass
(
tokenizer_pool_config
.
pool_type
,
BaseTokenizerGroup
):
tokenizer_cls
=
tokenizer_pool_config
.
pool_type
elif
tokenizer_pool_config
.
pool_type
==
"ray"
:
if
RayTokenizerGroupPool
is
None
:
raise
ImportError
(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool."
)
return
RayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
tokenizer_cls
=
RayTokenizerGroupPool
else
:
raise
ValueError
(
f
"Unknown pool type:
{
tokenizer_pool_config
.
pool_type
}
"
)
return
tokenizer_cls
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
500b93c8
...
...
@@ -3,12 +3,19 @@ from typing import List, Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
class
BaseTokenizerGroup
(
ABC
):
"""A group of tokenizers that can be used for LoRA adapters."""
@
classmethod
@
abstractmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
"BaseTokenizerGroup"
:
pass
@
abstractmethod
def
ping
(
self
)
->
bool
:
"""Check if the tokenizer group is alive."""
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
500b93c8
...
...
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls
=
TokenizerGroup
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
TokenizerPoolConfig
,
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
,
**
init_kwargs
)
->
"RayTokenizerGroupPool"
:
if
not
tokenizer_pool_config
:
raise
ValueError
(
"tokenizer_pool_config must not be None."
)
ray_actor_options
=
(
tokenizer_pool_config
.
extra_config
or
{
"num_cpus"
:
0
})
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
500b93c8
...
...
@@ -2,6 +2,7 @@ from typing import List, Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
get_lora_tokenizer_async
,
...
...
@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
lora_tokenizers
=
LRUCache
[
PreTrainedTokenizer
](
capacity
=
max_num_seqs
)
if
enable_lora
else
None
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
"TokenizerGroup"
:
return
cls
(
**
init_kwargs
)
def
ping
(
self
)
->
bool
:
"""Check if the tokenizer group is alive."""
return
True
...
...
vllm/usage/usage_lib.py
View file @
500b93c8
...
...
@@ -16,12 +16,12 @@ import requests
import
torch
import
vllm.envs
as
envs
from
vllm.connections
import
global_http_connection
from
vllm.version
import
__version__
as
VLLM_VERSION
_config_home
=
envs
.
VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH
=
os
.
path
.
join
(
_config_home
,
"vllm/usage_stats.json"
)
_USAGE_STATS_DO_NOT_TRACK_PATH
=
os
.
path
.
join
(
_config_home
,
"vllm/do_not_track"
)
_USAGE_STATS_JSON_PATH
=
os
.
path
.
join
(
_config_home
,
"usage_stats.json"
)
_USAGE_STATS_DO_NOT_TRACK_PATH
=
os
.
path
.
join
(
_config_home
,
"do_not_track"
)
_USAGE_STATS_ENABLED
=
None
_USAGE_STATS_SERVER
=
envs
.
VLLM_USAGE_STATS_SERVER
...
...
@@ -205,7 +205,8 @@ class UsageMessage:
def
_send_to_server
(
self
,
data
):
try
:
requests
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
global_http_client
=
global_http_connection
.
get_sync_client
()
global_http_client
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
except
requests
.
exceptions
.
RequestException
:
# silently ignore unless we are using debug log
logging
.
debug
(
"Failed to send usage data to server"
)
...
...
vllm/utils.py
View file @
500b93c8
...
...
@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union
)
import
numpy
as
np
import
numpy.typing
as
npt
import
psutil
import
torch
import
torch.types
...
...
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
# "fp8_e5m2": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
torch
.
float16
:
np
.
float16
,
torch
.
float32
:
np
.
float32
,
torch
.
float64
:
np
.
float64
,
torch
.
uint8
:
np
.
uint8
,
torch
.
int32
:
np
.
int32
,
torch
.
int64
:
np
.
int64
,
}
P
=
ParamSpec
(
'P'
)
K
=
TypeVar
(
"K"
)
T
=
TypeVar
(
"T"
)
...
...
@@ -415,9 +425,10 @@ def init_kmp_env():
os
.
environ
[
'KMP_REDUCTION_BARRIER_PATTERN'
]
=
"dist,dist"
def
chunk_list
(
lst
:
List
[
T
],
chunk_size
:
int
)
->
List
[
List
[
T
]]
:
def
chunk_list
(
lst
:
List
[
T
],
chunk_size
:
int
):
"""Yield successive chunk_size chunks from lst."""
return
[
lst
[
i
:
i
+
chunk_size
]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
)]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
):
yield
lst
[
i
:
i
+
chunk_size
]
def
cdiv
(
a
:
int
,
b
:
int
)
->
int
:
...
...
@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def
make_tensor_with_pad
(
x
:
List
[
List
[
int
]],
max_len
:
int
,
pad
:
int
,
dtype
:
torch
.
dtype
,
device
:
Optional
[
Union
[
str
,
torch
.
device
]],
)
->
torch
.
Tensor
:
"""Make a padded tensor of a 2D inputs.
def
make_ndarray_with_pad
(
x
:
List
[
List
[
T
]],
pad
:
T
,
dtype
:
npt
.
DTypeLike
,
*
,
max_len
:
Optional
[
int
]
=
None
,
)
->
npt
.
NDArray
:
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x
=
np
.
zeros
([
len
(
x
),
max_len
],
dtype
=
np
.
int32
)
+
pad
if
max_len
is
None
:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len
=
max
(
map
(
len
,
x
),
default
=
0
)
padded_x
=
np
.
full
((
len
(
x
),
max_len
),
pad
,
dtype
=
dtype
)
for
ind
,
blocktb
in
enumerate
(
x
):
assert
len
(
blocktb
)
<=
max_len
padded_x
[
ind
,
:
len
(
blocktb
)]
=
blocktb
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
return
padded_x
def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a padded tensor from a list of lists.

    The torch ``dtype`` is translated to a NumPy dtype through
    ``TORCH_DTYPE_TO_NUMPY_DTYPE`` (a ``KeyError`` is raised for dtypes missing
    from that table), the padded array is produced by
    ``make_ndarray_with_pad``, and the result is wrapped with
    ``torch.from_numpy`` and moved to ``device``. When ``pin_memory`` is true
    the tensor's ``pin_memory()`` result is returned instead.
    """
    padded_array = make_ndarray_with_pad(
        x,
        pad,
        TORCH_DTYPE_TO_NUMPY_DTYPE[dtype],
        max_len=max_len,
    )
    result = torch.from_numpy(padded_array).to(device)
    return result.pin_memory() if pin_memory else result
def
async_tensor_h2d
(
...
...
@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return
dict
(
merged_dict
)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the inner lists, in order, into one new flat list."""
    flat: List[T] = []
    for inner in lists:
        flat.extend(inner)
    return flat
def
init_cached_hf_modules
()
->
None
:
"""
Lazy initialization of the Hugging Face modules.
...
...
@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
processed_args
.
append
(
arg
)
return
super
().
parse_args
(
processed_args
,
namespace
)
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Await ``task(*args, **kwargs)`` while holding ``lock``.

    The lock is released once the task finishes, whether it returned or
    raised; the task's result is passed back to the caller.
    """
    await lock.acquire()
    try:
        outcome = await task(*args, **kwargs)
    finally:
        lock.release()
    return outcome
vllm/version.py
View file @
500b93c8
...
...
@@ -9,4 +9,4 @@ except Exception as e:
stacklevel
=
2
)
__commit__
=
"COMMIT_HASH_PLACEHOLDER"
__version__
=
"0.5.
2
"
__version__
=
"0.5.
3.post1
"
vllm/worker/cpu_model_runner.py
View file @
500b93c8
...
...
@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
dtype
=
torch
.
int
,
device
=
self
.
device
)
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
...
...
vllm/worker/model_runner.py
View file @
500b93c8
...
...
@@ -2,7 +2,8 @@ import dataclasses
import
gc
import
time
import
warnings
from
collections
import
defaultdict
import
weakref
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
...
...
@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
supports_vision
)
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
MultiModalInputs
)
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
...
...
@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
...
...
@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPU
(
ModelRunnerInputBase
):
"""
This base class contains metadata needed for the base model forward pass
...
...
@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return
cls
(
**
tensor_dict
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPUWithSamplingMetadata
(
ModelInputForGPU
):
"""
Used by the ModelRunner.
...
...
@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return
cls
(
**
tensor_dict
)
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@
dataclass
class
InterDataForSeqGroup
:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id
:
str
seq_ids
:
List
[
int
]
is_prompt
:
bool
block_tables
:
Optional
[
Dict
[
int
,
List
[
int
]]]
computed_block_nums
:
List
[
int
]
n_seqs
:
int
=
0
# Input tokens and positions.
input_tokens
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
input_positions
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
# The sequence length (may be capped to the sliding window).
seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The query length.
query_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The number of tokens that are already computed.
context_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The current sliding window block.
curr_sliding_window_blocks
:
List
[
int
]
=
field
(
default_factory
=
list
)
# LoRA inputs.
lora_index_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_prompt_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_requests
:
Set
[
LoRARequest
]
=
field
(
default_factory
=
set
)
# Prompt adapter inputs.
prompt_adapter_index_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_prompt_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
# Multi-modal inputs.
multi_modal_inputs
:
Optional
[
MultiModalInputs
]
=
None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit
:
bool
=
False
def
__post_init__
(
self
):
self
.
n_seqs
=
len
(
self
.
seq_ids
)
self
.
input_tokens
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_positions
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
orig_seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
query_lens
=
[
0
]
*
self
.
n_seqs
self
.
context_lens
=
[
0
]
*
self
.
n_seqs
self
.
curr_sliding_window_blocks
=
[
0
]
*
self
.
n_seqs
self
.
lora_index_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
lora_prompt_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
def
__init__
(
self
,
runner
:
"GPUModelRunnerBase"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self
.
per_seq_compute_fns
=
[
self
.
_compute_lens
,
self
.
_compute_for_prefix_cache_hit
,
self
.
_compute_for_sliding_window
,
self
.
_compute_lora_input
,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self
.
per_seq_group_compute_fns
=
[
self
.
_compute_prompt_adapter_input
,
self
.
_compute_multi_modal_input
,
]
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
scheduler_config
=
self
.
runner
.
scheduler_config
self
.
sliding_window
=
self
.
runner
.
sliding_window
self
.
block_size
=
self
.
runner
.
block_size
self
.
enable_lora
=
self
.
runner
.
lora_config
is
not
None
self
.
enable_prompt_adapter
=
(
self
.
runner
.
prompt_adapter_config
is
not
None
)
self
.
multi_modal_input_mapper
=
self
.
runner
.
multi_modal_input_mapper
self
.
finished_requests_ids
=
finished_requests_ids
self
.
decode_only
=
True
# Intermediate data (data in CPU before going to GPU) for
# the current sequence group.
self
.
inter_data_list
:
List
[
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
# Attention metadata inputs.
self
.
attn_metadata_builder
=
self
.
attn_backend
.
make_metadata_builder
(
weakref
.
proxy
(
self
))
# Engine/Model configurations.
self
.
chunked_prefill_enabled
=
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
)
if
self
.
sliding_window
is
not
None
:
self
.
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data
=
seq_group_metadata
.
seq_data
[
inter_data
.
seq_ids
[
seq_idx
]]
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len
=
seq_data
.
get_len
()
if
inter_data
.
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# Compute tokens.
if
inter_data
.
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
inter_data
.
seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
orig_seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
input_tokens
[
seq_idx
]
=
tokens
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
def
_compute_for_prefix_cache_hit
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Check if hit prefix cache (i.e., some blocks are already computed).
If hit, update input tokens and positions to only compute the
remaining blocks.
"""
computed_block_nums
=
inter_data
.
computed_block_nums
# Note that prefix caching does not support sliding window.
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
inter_data
.
is_prompt
)
inter_data
.
prefix_cache_hit
=
prefix_cache_hit
if
self
.
chunked_prefill_enabled
and
prefix_cache_hit
:
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching now."
)
# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
seq_idx
][
context_len
:]
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
seq_idx
][
context_len
:]
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
query_lens
[
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
context_len
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block
=
0
sliding_seq_len
=
inter_data
.
seq_lens
[
seq_idx
]
if
not
inter_data
.
is_prompt
and
self
.
sliding_window
is
not
None
:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle sliding window attn.
curr_sliding_window_block
=
self
.
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
inter_data
.
seq_lens
[
seq_idx
]
%
self
.
block_size
sliding_seq_len
=
min
(
inter_data
.
seq_lens
[
seq_idx
],
self
.
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_block
+=
1
else
:
sliding_seq_len
=
min
(
inter_data
.
seq_lens
[
seq_idx
],
self
.
sliding_window
)
inter_data
.
curr_sliding_window_blocks
[
seq_idx
]
=
curr_sliding_window_block
inter_data
.
seq_lens
[
seq_idx
]
=
sliding_seq_len
def
_compute_lora_input
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""If LoRA is enabled, compute LoRA index and prompt mapping."""
if
not
self
.
enable_lora
:
return
lora_id
=
seq_group_metadata
.
lora_int_id
if
lora_id
>
0
:
inter_data
.
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
query_len
=
inter_data
.
query_lens
[
seq_idx
]
inter_data
.
lora_index_mapping
.
append
([
lora_id
]
*
query_len
)
inter_data
.
lora_prompt_mapping
.
append
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
def
_compute_prompt_adapter_input
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""If prompt adapter is enabled, compute index and prompt mapping.
"""
# Note that when is_prompt=True, we expect only one sequence
# in the group.
if
not
self
.
enable_prompt_adapter
:
return
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
prompt_adapter_id
<=
0
or
not
inter_data
.
is_prompt
:
return
# We expect only one sequence in the group when is_prompt=True.
assert
inter_data
.
n_seqs
==
1
query_len
=
inter_data
.
query_lens
[
0
]
inter_data
.
prompt_adapter_request
=
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
prompt_adapter_num_virtual_tokens
inter_data
.
prompt_adapter_index_mapping
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
inter_data
.
prompt_adapter_prompt_mapping
=
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
)
def
_compute_multi_modal_input
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""If multi-modal data is given, add it to the input."""
mm_data
=
seq_group_metadata
.
multi_modal_data
if
not
mm_data
:
return
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
inter_data
.
multi_modal_inputs
=
mm_kwargs
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
    """Register a sequence group and run all compute hooks on it.

    A fresh ``InterDataForSeqGroup`` is created for the group, appended to
    ``self.inter_data_list``, then handed to every per-sequence hook (once
    per sequence index) followed by every per-group hook.
    """
    group_seq_ids = list(seq_group_metadata.seq_data.keys())
    num_seqs = len(group_seq_ids)
    prefill = seq_group_metadata.is_prompt

    if prefill:
        # A prompt group holds exactly one sequence, and seeing one means
        # the batch is no longer decode-only.
        assert num_seqs == 1
        self.decode_only = False

    init_kwargs = {
        "request_id": seq_group_metadata.request_id,
        "seq_ids": group_seq_ids,
        "is_prompt": prefill,
        "block_tables": seq_group_metadata.block_tables,
        "computed_block_nums": seq_group_metadata.computed_block_nums,
    }
    group_data = self.InterDataForSeqGroup(**init_kwargs)
    self.inter_data_list.append(group_data)

    # Per-sequence hooks first (sequence-major order), then group hooks.
    for idx in range(num_seqs):
        for compute_seq in self.per_seq_compute_fns:
            compute_seq(group_data, idx, seq_group_metadata)

    for compute_group in self.per_seq_group_compute_fns:
        compute_group(group_data, seq_group_metadata)
def build(self) -> ModelInputForGPU:
    """Finalize the builder intermediate data and
    create on-device tensors.

    Flattens the per-sequence-group data accumulated in
    ``self.inter_data_list`` into batch-level lists, pads them when a
    captured CUDA graph can be used, and packs everything into
    ``self.model_input_cls``.

    Returns:
        A ``self.model_input_cls`` instance. It is constructed with no
        arguments when the flattened token list is empty.
    """
    # Combine and flatten intermediate data.
    input_tokens = flatten_2d_lists([
        flatten_2d_lists(inter_data.input_tokens)
        for inter_data in self.inter_data_list
    ])
    if not input_tokens:
        # This may happen when all prefill requests hit
        # prefix caching and there is no decode request.
        return self.model_input_cls()
    input_positions = flatten_2d_lists([
        flatten_2d_lists(inter_data.input_positions)
        for inter_data in self.inter_data_list
    ])
    seq_lens = []
    max_decode_seq_len = 0
    for inter_data in self.inter_data_list:
        seq_lens.extend(inter_data.seq_lens)
        # Only decode groups contribute to the max decode length.
        if not inter_data.is_prompt:
            max_decode_seq_len = max(max_decode_seq_len,
                                     max(inter_data.seq_lens))
    query_lens = flatten_2d_lists(
        [inter_data.query_lens for inter_data in self.inter_data_list])
    # Mapping from request IDs to sequence IDs. Used for Jamba models
    # that manage the cache by themselves.
    request_ids_to_seq_ids = {
        data.request_id: data.seq_ids
        for data in self.inter_data_list
    }

    # The batch size here is the number of flattened input tokens.
    batch_size = len(input_tokens)
    # A captured graph is used only when every added group was a decode
    # group, eager mode is not enforced, the batch size does not exceed
    # the last entry of _BATCH_SIZES_TO_CAPTURE (presumably the largest
    # captured size -- confirm), and the longest decode sequence fits
    # max_seq_len_to_capture.
    use_captured_graph = (
        self.decode_only and not self.runner.model_config.enforce_eager
        and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
        and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

    # If cuda graph can be used, pad tensors accordingly.
    # See `capture_model` API for more details.
    # vLLM uses cuda graph only for decoding requests.
    # -1 means "no padding": `[x] * -1` is an empty list, so the extend()
    # calls below add nothing in that case. The value is also passed
    # unchanged to the attention metadata builder.
    cuda_graph_pad_size = -1
    if use_captured_graph:
        graph_batch_size = _get_graph_batch_size(batch_size)
        assert graph_batch_size >= batch_size
        cuda_graph_pad_size = graph_batch_size - batch_size
        batch_size = graph_batch_size

    # Tokens and positions.
    input_tokens.extend([0] * cuda_graph_pad_size)
    input_positions.extend([0] * cuda_graph_pad_size)
    input_tokens_tensor = torch.tensor(input_tokens,
                                       dtype=torch.long,
                                       device=self.runner.device)
    input_positions_tensor = torch.tensor(input_positions,
                                          dtype=torch.long,
                                          device=self.runner.device)

    # Sequence and query lengths.
    # Padding entries get a sequence length of 1; query_lens is not padded.
    seq_lens.extend([1] * cuda_graph_pad_size)

    # Attention metadata.
    attn_metadata = self.attn_metadata_builder.build(
        seq_lens, query_lens, cuda_graph_pad_size, batch_size)

    # LoRA data.
    lora_requests = set()
    lora_mapping = None
    if self.enable_lora:
        lora_requests = set(r for data in self.inter_data_list
                            for r in data.lora_requests)
        lora_index_mapping = flatten_2d_lists([
            flatten_2d_lists(inter_data.lora_index_mapping)
            for inter_data in self.inter_data_list
        ])
        lora_index_mapping.extend([0] * cuda_graph_pad_size)
        # NOTE(review): unlike the index mapping, the prompt mapping is
        # not padded here.
        lora_prompt_mapping = flatten_2d_lists([
            flatten_2d_lists(inter_data.lora_prompt_mapping)
            for inter_data in self.inter_data_list
        ])
        lora_mapping = LoRAMapping(
            lora_index_mapping,
            lora_prompt_mapping,
        )

    # Prompt adapter data.
    prompt_adapter_requests: Set[PromptAdapterRequest] = set()
    prompt_adapter_mapping = None
    if self.enable_prompt_adapter:
        # Groups without a prompt adapter request are skipped.
        prompt_adapter_requests = set(
            data.prompt_adapter_request for data in self.inter_data_list
            if data.prompt_adapter_request is not None)
        # These per-group mappings are flattened once (not twice like the
        # LoRA mappings above).
        prompt_adapter_index_mapping = flatten_2d_lists([
            inter_data.prompt_adapter_index_mapping
            for inter_data in self.inter_data_list
        ])
        prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
        prompt_adapter_prompt_mapping = flatten_2d_lists([
            inter_data.prompt_adapter_prompt_mapping
            for inter_data in self.inter_data_list
        ])
        prompt_adapter_mapping = PromptAdapterMapping(
            prompt_adapter_index_mapping,
            prompt_adapter_prompt_mapping,
        )

    # Multi-modal data.
    # Only groups that actually carry multi-modal inputs are batched.
    multi_modal_inputs_list = [
        data.multi_modal_inputs for data in self.inter_data_list
        if data.multi_modal_inputs is not None
    ]
    multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                device=self.runner.device)

    return self.model_input_cls(
        input_tokens=input_tokens_tensor,
        input_positions=input_positions_tensor,
        attn_metadata=attn_metadata,
        seq_lens=seq_lens,
        query_lens=query_lens,
        lora_mapping=lora_mapping,
        lora_requests=lora_requests,
        multi_modal_kwargs=multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids,
        finished_requests_ids=self.finished_requests_ids,
        prompt_adapter_mapping=prompt_adapter_mapping,
        prompt_adapter_requests=prompt_adapter_requests)
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
"""
Helper class for shared methods between GPU model runners.
...
...
@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
device_config
=
self
.
device_config
,
...
...
@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt_adapter_index_mapping
:
List
[
int
]
=
[]
prompt_adapter_prompt_mapping
:
List
[
int
]
=
[]
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
seq_lens
:
List
[
int
]
=
[]
prefill_seq_lens
:
List
[
int
]
=
[]
decode_seq_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
request_ids_to_seq_ids
:
Dict
[
str
,
List
[
int
]]
=
defaultdict
(
list
)
decode_only
=
True
num_prefills
=
0
num_prefill_tokens
=
0
num_decode_tokens
=
0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len
:
List
[
int
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_model_input_cls
()
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
block_aligned_sliding_window
=
\
sliding_window_blocks
*
self
.
block_size
builder
=
ModelInputForGPUBuilder
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
is_prompt
=
seq_group_metadata
.
is_prompt
for
seq_id
in
seq_ids
:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"now."
)
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_data
.
get_len
()
-
1
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
seq_group_metadata
.
token_chunk_size
)
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks
=
None
sliding_seq_len
=
seq_len
sliding_context_len
=
context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle sliding window attn.
if
(
self
.
sliding_window
is
not
None
and
not
is_prompt
):
curr_sliding_window_blocks
=
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
sliding_seq_len
=
min
(
seq_len
,
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_blocks
+=
1
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
sliding_context_len
=
sliding_seq_len
-
1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert
self
.
sliding_window
is
None
,
\
"Prefix caching is not supported with sliding window"
sliding_context_len
=
context_len
if
self
.
attn_backend
.
get_name
()
==
"flash-attn"
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
else
:
block_table
=
computed_block_nums
elif
(
self
.
scheduler_config
.
chunked_prefill_enabled
or
not
is_prompt
):
if
seq_group_metadata
.
block_tables
is
not
None
:
# chunked prefill or decode
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
if
curr_sliding_window_blocks
is
not
None
:
block_table
=
block_table
[
-
curr_sliding_window_blocks
:]
else
:
# Only happens when memory profiling runs.
block_table
=
[]
else
:
# Prefill without chunked prefill or memory profiling.
block_table
=
[]
block_tables
.
append
(
block_table
)
seq_lens
.
append
(
sliding_seq_len
)
context_lens
.
append
(
sliding_context_len
)
query_len
=
sliding_seq_len
-
sliding_context_len
query_lens
.
append
(
query_len
)
input_tokens
.
extend
(
tokens
)
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
lora_id
=
seq_group_metadata
.
lora_int_id
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
num_prefills
+=
1
num_prefill_tokens
+=
len
(
tokens
)
decode_only
=
False
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
num_decode_tokens
+=
query_len
decode_seq_lens
.
append
(
sliding_seq_len
)
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
query_len
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
# Process multi-modal data
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
if
prompt_adapter_id
>
0
and
is_prompt
:
prompt_adapter_requests
.
add
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
\
prompt_adapter_num_virtual_tokens
pm
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
prompt_adapter_index_mapping
+=
pm
prompt_adapter_prompt_mapping
.
extend
(
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is_profile_run
=
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
)
if
is_profile_run
:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
is_prompt
:
assert
self
.
scheduler_config
.
use_v2_block_manager
\
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager"
)
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx
=
max
(
0
,
query_len
-
self
.
sliding_window
)
for
i
in
range
(
context_len
,
seq_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
# Prepare input tensors for flashinfer
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
seq_len
=
seq_data
.
get_len
()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
paged_kv_last_page_len
.
append
(
last_page_len
)
batch_size
=
len
(
input_tokens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
decode_seq_lens
,
default
=
0
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph
=
(
decode_only
and
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
max_seq_len_to_capture
)
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
for
_
in
range
(
graph_batch_size
-
batch_size
):
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
seq_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
prompt_adapter_index_mapping
.
append
(
0
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
last_paged_kv_indptr
=
paged_kv_indptr
[
-
1
]
paged_kv_indptr
.
append
(
last_paged_kv_indptr
)
paged_kv_last_page_len
.
append
(
0
)
batch_size
=
graph_batch_size
num_decode_tokens
=
batch_size
if
use_captured_graph
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
self
.
device
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions_tensor
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping_tensor
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
logits_soft_cap
=
getattr
(
self
.
model_config
.
hf_config
,
'attn_logit_softcapping'
,
None
)
if
logits_soft_cap
is
not
None
and
self
.
attn_backend
.
get_name
(
)
!=
"flashinfer"
:
raise
ValueError
(
"Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
len
(
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
paged_kv_indices
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
paged_kv_indptr
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
paged_kv_last_page_len
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
self
.
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
logits_soft_cap
=
logits_soft_cap
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
if
self
.
prompt_adapter_config
:
prompt_adapter_mapping
=
PromptAdapterMapping
(
prompt_adapter_index_mapping
,
prompt_adapter_prompt_mapping
,
)
else
:
prompt_adapter_mapping
=
None
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
request_ids_to_seq_ids
=
{
seq_group_metadata
.
request_id
:
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_group_metadata
in
seq_group_metadata_list
}
return
self
.
_model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
attn_metadata
=
attn_metadata
,
seq_lens
=
seq_lens
,
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
,
multi_modal_kwargs
=
multi_modal_kwargs
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids
,
finished_requests_ids
=
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
prompt_adapter_requests
,
)
builder
.
add_seq_group
(
seq_group_metadata
)
return
builder
.
build
()
# type: ignore
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
...
...
@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dummy_lora_request
=
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_
local_
path
=
"/not/a/real/path"
,
lora_path
=
"/not/a/real/path"
,
)
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
rank
=
LORA_WARMUP_RANK
)
...
...
@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def
_is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if
block_tables
is
None
:
return
True
if
isinstance
(
block_tables
,
dict
)
and
all
(
value
is
None
for
value
in
block_tables
.
values
()):
return
True
return
False
vllm/worker/model_runner_base.py
View file @
500b93c8
...
...
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import
torch
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
...
...
@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC):
raise
NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.

    Abstract interface, generic over the type ``T`` returned by ``build``.
    Both methods are abstract, so concrete builders must override them.
    """

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the builder.

        Abstract: this base implementation only raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors.

        Abstract: this base implementation only raises
        ``NotImplementedError``. Positional and keyword arguments are left
        open for subclasses to define.
        """
        raise NotImplementedError
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
"""
Model runner interface that abstracts a particular hardware and/or type of
...
...
@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise
NotImplementedError
@
torch
.
inference_mode
()
@
current_platform
.
inference_mode
()
def
execute_model
(
self
,
model_input
:
T
,
...
...
vllm/worker/neuron_model_runner.py
View file @
500b93c8
...
...
@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
max_seq_len
=
max
(
seq_lens
)
assert
max_seq_len
>
0
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_seq_len
,
pad
=
0
,
max_len
=
max_seq_len
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_seq_len
,
pad
=
0
,
max_len
=
max_seq_len
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_block_ids
=
torch
.
tensor
(
input_block_ids
,
...
...
@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids
.
append
(
block_table
[
0
])
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_len
=
1
,
pad
=
0
,
max_len
=
1
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_len
=
1
,
pad
=
0
,
max_len
=
1
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
...
...
vllm/worker/tpu_model_runner.py
View file @
500b93c8
import
time
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
torch
...
...
@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
MultiModalInputs
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
...
...
@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES
=
128
class
TPUModelRunner
:
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Immutable bundle of the inputs passed to one TPU model step.

    The instance can be flattened into a plain dict with
    ``as_broadcastable_tensor_dict`` and rebuilt from such a dict with
    ``from_broadcasted_tensor_dict``.
    """
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    # presumably per-sequence sampling temperature / top-p -- TODO confirm
    # against the code that fills them in.
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    best_of: List[int]
    seq_groups: List[List[int]]

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Flatten this input into a dict suitable for broadcasting.

        Every field is emitted under its own name, except
        ``attn_metadata``, which is merged in by
        ``_add_attn_metadata_broadcastable_dict``.

        ``best_of`` and ``seq_groups`` are included because they are
        required constructor fields: without them
        ``from_broadcasted_tensor_dict`` fails with ``TypeError`` when it
        calls ``cls(**tensor_dict)`` on the received dict.
        """
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "best_of": self.best_of,
            "seq_groups": self.seq_groups,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild a model input from a broadcast dict.

        Args:
            tensor_dict: Mapping of constructor field names to values, as
                produced by ``as_broadcastable_tensor_dict``.
            attn_backend: When given, the dict is first passed through
                ``_init_attn_metadata_from_tensor_dict`` with this
                backend; when None the dict is used unchanged.

        Returns:
            A new instance built from the (possibly rewritten) dict.

        Raises:
            TypeError: If the dict does not match the constructor fields.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
class
TPUModelRunner
(
ModelRunnerBase
[
ModelInputForTPU
]):
def
__init__
(
self
,
...
...
@@ -68,10 +110,6 @@ class TPUModelRunner:
False
,
)
# Multi-modal data support
self
.
multi_modal_input_mapper
=
MULTIMODAL_REGISTRY
\
.
create_input_mapper
(
self
.
model_config
)
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
...
...
@@ -85,6 +123,7 @@ class TPUModelRunner:
multimodal_config
=
self
.
multimodal_config
,
lora_config
=
None
,
)
model
=
model
.
eval
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
...
...
@@ -153,8 +192,8 @@ class TPUModelRunner:
# Dummy run.
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
input_
le
n
s
,
None
,
t
,
p
,
num_sampl
es
)
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samp
les
,
kv_cach
es
)
def
warmup_model
(
self
,
...
...
@@ -183,7 +222,7 @@ class TPUModelRunner:
# Decode
start
=
time
.
time
()
seq_len
=
1
batch_size
=
1
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
xm
.
wait_device_ops
()
...
...
@@ -199,14 +238,12 @@ class TPUModelRunner:
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
Mapping
[
str
,
BatchedTensors
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]
]
=
[]
input_positions
:
List
[
List
[
int
]
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
...
...
@@ -220,78 +257,62 @@ class TPUModelRunner:
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
app
end
(
prompt_tokens
)
input_positions
.
app
end
(
list
(
range
(
prompt_len
)))
input_tokens
.
ext
end
(
prompt_tokens
)
input_positions
.
ext
end
(
list
(
range
(
prompt_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
slot_mapping
.
append
([])
for
i
in
range
(
prompt_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
slot_mapping
.
append
(
slot
)
# Add paddings to EACH prompt to the smallest power of 2 that is
# greater than or equal to the prompt length.
# We pad the seq_len to reduce the compilation overhead.
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len
=
_get_padded_prefill_len
(
prompt_len
)
num_paddings
=
padded_prompt_len
-
prompt_len
input_tokens
+=
[
0
]
*
num_paddings
input_positions
+=
[
0
]
*
num_paddings
slot_mapping
+=
[
_PAD_SLOT_ID
]
*
num_paddings
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
# Add paddings to make the shape [batch_size, max_prompt_len] where
# max_prompt_len is smallest power of 2 that is greater than or equal
# to the maximum prompt length.
# We need the 2D input shape because the Pallas FlashAttention kernel
# does not support packed 1D inputs.
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
max_prompt_len
=
_get_padded_prefill_len
(
max
(
prompt_lens
))
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
make_tensor_with_pad
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
# NOTE: This is not used.
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
block_tables
=
None
,
context_lens
=
None
,
)
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
multi_modal_kwargs
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
Mapping
[
str
,
BatchedTensors
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
batch_idx
=
0
for
seq_group_metadata
in
seq_group_metadata_list
:
...
...
@@ -317,11 +338,6 @@ class TPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
batch_size
=
_get_padded_batch_size
(
batch_idx
)
num_paddings
=
batch_size
-
batch_idx
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
...
...
@@ -331,22 +347,22 @@ class TPUModelRunner:
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
...
...
@@ -355,12 +371,7 @@ class TPUModelRunner:
block_tables
=
block_tables
,
context_lens
=
context_lens
,
)
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
multi_modal_kwargs
)
return
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
def
_prepare_sample
(
self
,
...
...
@@ -412,16 +423,18 @@ class TPUModelRunner:
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
t
,
p
,
best_of
def
_execut
e_model
(
def
prepar
e_model
_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
List
[
CompletionSequenceGroupOutput
]:
# Prepare inputs.
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
ModelInputForTPU
:
del
finished_requests_ids
# Unused.
assert
virtual_engine
==
0
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
...
...
@@ -430,16 +443,104 @@ class TPUModelRunner:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
=
inputs
padded_batch_size
=
input_tokens
.
shape
[
0
]
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
# Execute the model.
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
*
inputs
[
2
:],
t
,
p
,
num_samples
)
# Retrieve the outputs to CPU.
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
seq_groups
=
[
list
(
metadata
.
seq_data
.
keys
())
for
metadata
in
seq_group_metadata_list
]
return
ModelInputForTPU
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
best_of
,
seq_groups
)
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForTPU
:
model_input
=
ModelInputForTPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
)
return
model_input
def
execute_model
(
self
,
model_input
:
ModelInputForTPU
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
assert
intermediate_tensors
is
None
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
def
_execute_model
(
*
args
,
clone
:
bool
=
False
)
->
torch
.
Tensor
:
"""Move input args from CPU to device and execute the model."""
def
_copy_to_device
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
clone
:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x
=
x
.
clone
()
return
x
.
to
(
self
.
device
)
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
arg
=
_copy_to_device
(
arg
)
elif
isinstance
(
arg
,
AttentionMetadata
):
arg
.
slot_mapping
=
_copy_to_device
(
arg
.
slot_mapping
)
if
getattr
(
arg
,
"block_tables"
,
None
)
is
not
None
:
arg
.
block_tables
=
_copy_to_device
(
arg
.
block_tables
)
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
_copy_to_device
(
arg
.
context_lens
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
)
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
is_prompt
=
num_prefills
>
0
if
is_prompt
:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids
=
[]
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
for
i
in
range
(
batch_size
):
# Get the actual prefill_len.
prefill_len
=
model_input
.
input_lens
[
i
:
i
+
1
].
item
()
prefill_len
=
_get_padded_prefill_len
(
prefill_len
)
end_idx
=
start_idx
+
prefill_len
model_input
.
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
]
model_input
.
attn_metadata
.
num_prefills
=
1
output_token_ids
=
_execute_model
(
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
,
clone
=
True
)
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
else
:
# Execute the model.
output_token_ids
=
_execute_model
(
model_input
.
token_ids
,
model_input
.
position_ids
,
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
# Retrieve the outputs to CPU.
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
...
...
@@ -447,13 +548,13 @@ class TPUModelRunner:
zero_logprob
=
Logprob
(
0.0
)
batch_idx
=
0
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group
in
model_input
.
seq_groups
:
seq_ids
=
seq_group
seq_outputs
=
[]
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
for
i
in
range
(
best_of
[
batch_idx
]):
for
i
in
range
(
model_input
.
best_of
[
batch_idx
]):
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
...
...
@@ -468,35 +569,6 @@ class TPUModelRunner:
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
sampler_outputs
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
if
seq_group_metadata_list
[
0
].
is_prompt
:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampler_outputs
+=
self
.
_execute_model
([
seq_group_metadata
],
kv_caches
)
else
:
sampler_outputs
=
self
.
_execute_model
(
seq_group_metadata_list
,
kv_caches
)
return
[
SamplerOutput
(
sampler_outputs
)]
...
...
@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
self
.
model
=
model
.
eval
()
self
.
model
=
model
def
forward
(
self
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
multi_modal_kwargs
:
Optional
[
Mapping
[
str
,
BatchedTensors
]],
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
batch_size
,
seq_len
=
token_ids
.
shape
# Calculate the positions to sample from.
base
_indicies
=
torch
.
arange
(
start
_indicies
=
torch
.
arange
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
logits_indices
=
base
_indicies
+
input_lens
-
1
logits_indices
=
start
_indicies
+
input_lens
-
1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
...
...
@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module):
position_ids
,
kv_caches
,
attn_metadata
,
**
(
multi_modal_kwargs
or
{}),
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
...
@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int:
def
_get_padded_batch_size
(
batch_size
:
int
)
->
int
:
if
batch_size
<=
2
:
return
batch_size
elif
batch_size
<=
4
:
return
4
elif
batch_size
<=
8
:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if
batch_size
<=
8
:
return
8
else
:
return
((
batch_size
+
15
)
//
16
)
*
16
...
...
vllm/worker/tpu_worker.py
View file @
500b93c8
...
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoraNotSupportedWorkerBase
):
class
TPUWorker
(
LoraNotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def
__init__
(
self
,
...
...
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
multimodal_config
,
is_driver_worker
=
is_driver_worker
)
self
.
model_runner
:
TPUModelRunner
=
TPUModelRunner
(
model_config
,
parallel_config
,
scheduler_config
,
device_config
,
cache_config
,
load_config
,
multimodal_config
,
is_driver_worker
=
is_driver_worker
)
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
...
...
@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results.
xr
.
initialize_cache
(
os
.
path
.
expanduser
(
envs
.
VLLM_XLA_CACHE_PATH
),
readonly
=
False
)
xr
.
initialize_cache
(
envs
.
VLLM_XLA_CACHE_PATH
,
readonly
=
False
)
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
...
...
@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
def
execute_model
(
@
property
def
do_metadata_broadcast
(
self
)
->
bool
:
# TODO(woosuk): Support TP.
return
False
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return
[
self
.
tpu_cache
]
def
prepare_worker_input
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
List
[
SamplerOutput
]:
if
not
self
.
is_driver_worker
:
self
.
_execute_model_non_driver
()
return
[]
assert
execute_model_req
is
not
None
# Issue cache operations.
self
.
cache_swap
(
execute_model_req
.
blocks_to_swap_in
,
execute_model_req
.
blocks_to_swap_out
,
execute_model_req
.
blocks_to_copy
,
execute_model_req
:
ExecuteModelRequest
,
)
->
WorkerInput
:
virtual_engine
=
execute_model_req
.
virtual_engine
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
blocks_to_swap_in
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
blocks_to_swap_out
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
blocks_to_copy
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_copy
,
self
.
device
,
self
.
device
)
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
# Run the model.
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
assert
len
(
seq_group_metadata_list
)
>
0
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
tpu_cache
)
return
output
def
cache_swap
(
self
,
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
def
execute_worker
(
self
,
worker_input
:
WorkerInput
)
->
None
:
virtual_engine
=
worker_input
.
virtual_engine
assert
virtual_engine
==
0
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
if
blocks_to_swap_in
:
# Swap from CPU to TPU.
src_indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
blocks_to_swap_out
:
# Swap from TPU to CPU.
src_indices
,
dst_indices
=
_make_src_to_dst
(
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
].
cpu
()
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
].
cpu
()
if
blocks_to_copy
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
self
.
device
)
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
src_to_dst
)
def
start_worker_execution_loop
(
self
)
->
None
:
while
self
.
_execute_model_non_driver
():
pass
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
# Issue cache operations.
if
worker_input
.
blocks_to_swap_in
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_in
if
src_indices
.
numel
()
>
0
:
# Swap from CPU to TPU.
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
worker_input
.
blocks_to_swap_out
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_out
if
src_indices
.
numel
()
>
0
:
# Swap from TPU to CPU.
for
i
in
range
(
num_layers
):
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
]
if
worker_input
.
blocks_to_copy
is
not
None
:
src_indices
,
dst_indices
=
worker_input
.
blocks_to_copy
if
src_indices
.
numel
()
>
0
:
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
(
src_indices
,
dst_indices
))
def
_make_src_to_dst
(
...
...
vllm/worker/worker.py
View file @
500b93c8
...
...
@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
# initialize_cache.
self
.
cache_engine
:
List
[
CacheEngine
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
t
ensor
]]]
=
None
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
T
ensor
]]]
=
None
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment