Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
"vllm/vscode:/vscode.git/clone" did not exist on "7439b2056adb22894f6e56c629f4ffc275d1e63d"
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1001 additions
and
730 deletions
+1001
-730
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+4
-4
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-4
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+4
-0
vllm/transformers_utils/configs/chameleon.py
vllm/transformers_utils/configs/chameleon.py
+138
-0
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+8
-0
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+5
-3
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+9
-5
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+7
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+3
-1
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+6
-0
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+5
-4
vllm/utils.py
vllm/utils.py
+66
-12
vllm/version.py
vllm/version.py
+1
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+437
-476
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+17
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+4
-4
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+208
-140
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+73
-71
vllm/worker/worker.py
vllm/worker/worker.py
+1
-1
No files found.
vllm/spec_decode/util.py
View file @
500b93c8
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -53,8 +53,8 @@ def create_sequence_group_output(
...
@@ -53,8 +53,8 @@ def create_sequence_group_output(
token_id_logprob_rank
:
int
,
token_id_logprob_rank
:
int
,
token_id_logprob
:
float
,
token_id_logprob
:
float
,
seq_id
:
SeqId
,
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_token_ids
:
List
[
Optional
[
int
]
]
,
topk_logprobs
:
List
[
float
],
topk_logprobs
:
List
[
Optional
[
float
]
]
,
)
->
CompletionSequenceGroupOutput
:
)
->
CompletionSequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
"""Create a SequenceGroupOutput given the sampling results.
...
@@ -68,7 +68,7 @@ def create_sequence_group_output(
...
@@ -68,7 +68,7 @@ def create_sequence_group_output(
"""
"""
# vLLM logprobs always include the sampled token. In addition, the user may
# vLLM logprobs always include the sampled token. In addition, the user may
# request topk-logprobs (where top-k varies per user up to max_logprobs).
# request topk-logprobs (where top-k varies per user up to max_logprobs).
logprobs
:
Dict
[
int
,
Logprob
]
=
{
logprobs
:
Dict
[
Optional
[
int
]
,
Logprob
]
=
{
token_id
:
Logprob
(
token_id
:
Logprob
(
logprob
=
token_id_logprob
,
logprob
=
token_id_logprob
,
rank
=
token_id_logprob_rank
,
rank
=
token_id_logprob_rank
,
...
...
vllm/transformers_utils/config.py
View file @
500b93c8
...
@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
...
@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
Cha
tGLM
Config
,
Dbrx
Config
,
from
vllm.transformers_utils.configs
import
(
Cha
meleon
Config
,
ChatGLM
Config
,
JAIS
Config
,
Medusa
Config
,
Dbrx
Config
,
JAIS
Config
,
MLPSpeculator
Config
,
MPT
Config
,
M
edusaConfig
,
M
LPSpeculatorConfig
,
RWConfig
)
MPTConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
from
modelscope
import
AutoConfig
...
@@ -18,6 +18,7 @@ else:
...
@@ -18,6 +18,7 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
"chameleon"
:
ChameleonConfig
,
"chatglm"
:
ChatGLMConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
"mpt"
:
MPTConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
500b93c8
from
vllm.transformers_utils.configs.chameleon
import
(
ChameleonConfig
,
ChameleonVQVAEConfig
)
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...
@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
...
@@ -10,6 +12,8 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
__all__
=
[
__all__
=
[
"ChameleonConfig"
,
"ChameleonVQVAEConfig"
,
"ChatGLMConfig"
,
"ChatGLMConfig"
,
"DbrxConfig"
,
"DbrxConfig"
,
"MPTConfig"
,
"MPTConfig"
,
...
...
vllm/transformers_utils/configs/chameleon.py
0 → 100644
View file @
500b93c8
from
typing
import
List
,
Optional
from
transformers
import
PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):
    """Configuration container for a Chameleon model.

    Every constructor argument is stored unchanged as an attribute of the
    same name, except that:

    * ``vq_config`` (a dict of keyword arguments, or ``None`` for an empty
      dict) is expanded into a :class:`ChameleonVQVAEConfig` instance, and
    * ``pad_token_id``, ``bos_token_id``, ``eos_token_id``,
      ``tie_word_embeddings`` and any extra ``**kwargs`` are forwarded to
      ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if ``rope_scaling`` is not ``None`` and fails
            :meth:`_rope_scaling_validation`.
    """

    model_type = "chameleon"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        model_parallel_size=1,
        swin_norm=False,
        vq_config=None,
        vocabulary_map=None,
        mlp_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_bias = mlp_bias

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate right after assignment so a bad value fails fast, before
        # the rest of the configuration is populated.
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.model_parallel_size = model_parallel_size
        self.swin_norm = swin_norm

        # A missing VQ-VAE section means "use ChameleonVQVAEConfig defaults".
        if vq_config is None:
            vq_config = {}

        self.vq_config = ChameleonVQVAEConfig(**vq_config)

        self.vocabulary_map = vocabulary_map

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        ``None`` is accepted as-is. Otherwise ``rope_scaling`` must be a dict
        with exactly two entries: ``type`` (``"linear"`` or ``"dynamic"``)
        and ``factor`` (a ``float`` strictly greater than 1).

        Raises:
            ValueError: if any of the above constraints is violated.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling,
                          dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, "
                f"`type` and `factor`, got {self.rope_scaling}")
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in [
                "linear", "dynamic"
        ]:
            raise ValueError(
                "`rope_scaling`'s type field must be one of ['linear', "
                f"'dynamic'], got {rope_scaling_type}")
        # NOTE: the isinstance check means an integer factor (e.g. 2) is
        # rejected; only float values pass.
        if rope_scaling_factor is None or not isinstance(
                rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(
                "`rope_scaling`'s factor field must be a float > 1, "
                f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
    """Configuration container for the Chameleon VQ-VAE component.

    All named constructor arguments are stored unchanged as attributes of
    the same name; extra ``**kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.

    ``channel_multiplier`` defaults to ``[1, 1, 2, 2, 4]``. A fresh list is
    created for every instance so that mutating one config's
    ``channel_multiplier`` cannot leak into other instances.
    """

    model_type = "chameleon_vqgan"

    def __init__(
        self,
        embed_dim: int = 256,
        num_embeddings: int = 8192,
        double_latent: bool = False,
        latent_channels: int = 256,
        resolution: int = 512,
        in_channels: int = 3,
        base_channels: int = 128,
        channel_multiplier: Optional[List[int]] = None,
        num_res_blocks: int = 2,
        attn_resolutions: Optional[List[int]] = None,
        dropout: float = 0.0,
        attn_type: str = "vanilla",
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_embeddings = num_embeddings
        self.double_latent = double_latent
        self.latent_channels = latent_channels
        self.resolution = resolution
        self.in_channels = in_channels
        self.base_channels = base_channels
        # None is the sentinel for "use the default"; build the default list
        # here rather than in the signature to avoid a shared mutable default.
        self.channel_multiplier = ([1, 1, 2, 2, 4] if channel_multiplier
                                   is None else channel_multiplier)
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.attn_type = attn_type
        self.initializer_range = initializer_range
vllm/transformers_utils/detokenizer.py
View file @
500b93c8
...
@@ -165,6 +165,12 @@ class Detokenizer:
...
@@ -165,6 +165,12 @@ class Detokenizer:
return
len
(
new_decoded_token_text
)
return
len
(
new_decoded_token_text
)
def _replace_none_with_empty(tokens: List[Optional[str]]):
    """Overwrite every ``None`` entry of ``tokens`` with ``""``, in place.

    The list object itself is mutated; nothing is returned.
    """
    missing_positions = [pos for pos, tok in enumerate(tokens) if tok is None]
    for pos in missing_positions:
        tokens[pos] = ""
def
_convert_tokens_to_string_with_added_encoders
(
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
output_tokens
:
List
[
str
],
output_tokens
:
List
[
str
],
...
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
...
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset
=
len
(
new_tokens
)
read_offset
=
len
(
new_tokens
)
prefix_offset
=
max
(
prefix_offset
=
max
(
read_offset
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
read_offset
-
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
,
0
)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty
(
new_tokens
)
return
new_tokens
,
prefix_offset
,
read_offset
return
new_tokens
,
prefix_offset
,
read_offset
...
...
vllm/transformers_utils/tokenizer.py
View file @
500b93c8
...
@@ -88,6 +88,9 @@ def get_tokenizer(
...
@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode."
)
"Cannot use the fast tokenizer in slow tokenizer mode."
)
kwargs
[
"use_fast"
]
=
False
kwargs
[
"use_fast"
]
=
False
if
"truncation_side"
not
in
kwargs
:
kwargs
[
"truncation_side"
]
=
"left"
try
:
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
tokenizer_name
,
...
@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
...
@@ -134,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if
lora_request
is
None
:
if
lora_request
is
None
:
return
None
return
None
try
:
try
:
tokenizer
=
get_tokenizer
(
lora_request
.
lora_local_path
,
*
args
,
tokenizer
=
get_tokenizer
(
lora_request
.
lora_path
,
*
args
,
**
kwargs
)
**
kwargs
)
except
OSError
as
e
:
except
OSError
as
e
:
# No tokenizer was found in the LoRA folder,
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
# use base model tokenizer
logger
.
warning
(
logger
.
warning
(
"No tokenizer found in %s, using base model tokenizer instead. "
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)"
,
lora_request
.
lora_
local_
path
,
e
)
"(Exception: %s)"
,
lora_request
.
lora_path
,
e
)
tokenizer
=
None
tokenizer
=
None
return
tokenizer
return
tokenizer
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
500b93c8
from
typing
import
Optional
from
typing
import
Optional
,
Type
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
...
@@ -16,18 +16,22 @@ else:
...
@@ -16,18 +16,22 @@ else:
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
BaseTokenizerGroup
:
**
init_kwargs
)
->
BaseTokenizerGroup
:
tokenizer_cls
:
Type
[
BaseTokenizerGroup
]
if
tokenizer_pool_config
is
None
:
if
tokenizer_pool_config
is
None
:
return
TokenizerGroup
(
**
init_kwargs
)
tokenizer_cls
=
TokenizerGroup
if
tokenizer_pool_config
.
pool_type
==
"ray"
:
elif
isinstance
(
tokenizer_pool_config
.
pool_type
,
type
)
and
issubclass
(
tokenizer_pool_config
.
pool_type
,
BaseTokenizerGroup
):
tokenizer_cls
=
tokenizer_pool_config
.
pool_type
elif
tokenizer_pool_config
.
pool_type
==
"ray"
:
if
RayTokenizerGroupPool
is
None
:
if
RayTokenizerGroupPool
is
None
:
raise
ImportError
(
raise
ImportError
(
"RayTokenizerGroupPool is not available. Please install "
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool."
)
"the ray package to use the Ray tokenizer group pool."
)
return
RayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_cls
=
RayTokenizerGroupPool
**
init_kwargs
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unknown pool type:
{
tokenizer_pool_config
.
pool_type
}
"
)
f
"Unknown pool type:
{
tokenizer_pool_config
.
pool_type
}
"
)
return
tokenizer_cls
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
500b93c8
...
@@ -3,12 +3,19 @@ from typing import List, Optional
...
@@ -3,12 +3,19 @@ from typing import List, Optional
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
class
BaseTokenizerGroup
(
ABC
):
class
BaseTokenizerGroup
(
ABC
):
"""A group of tokenizers that can be used for LoRA adapters."""
"""A group of tokenizers that can be used for LoRA adapters."""
@classmethod
@abstractmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
                **init_kwargs) -> "BaseTokenizerGroup":
    """Alternate constructor that every concrete tokenizer group must provide.

    Args:
        tokenizer_pool_config: Pool configuration, or ``None``.
        **init_kwargs: Keyword arguments for constructing the group.

    Returns:
        An instance of the concrete tokenizer group.
    """
    pass
@
abstractmethod
@
abstractmethod
def
ping
(
self
)
->
bool
:
def
ping
(
self
)
->
bool
:
"""Check if the tokenizer group is alive."""
"""Check if the tokenizer group is alive."""
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
500b93c8
...
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls
=
TokenizerGroup
_worker_cls
=
TokenizerGroup
@
classmethod
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
TokenizerPoolConfig
,
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
,
**
init_kwargs
)
->
"RayTokenizerGroupPool"
:
**
init_kwargs
)
->
"RayTokenizerGroupPool"
:
if
not
tokenizer_pool_config
:
raise
ValueError
(
"tokenizer_pool_config must not be None."
)
ray_actor_options
=
(
tokenizer_pool_config
.
extra_config
or
{
ray_actor_options
=
(
tokenizer_pool_config
.
extra_config
or
{
"num_cpus"
:
0
"num_cpus"
:
0
})
})
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
500b93c8
...
@@ -2,6 +2,7 @@ from typing import List, Optional
...
@@ -2,6 +2,7 @@ from typing import List, Optional
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
get_lora_tokenizer_async
,
get_lora_tokenizer_async
,
...
@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -24,6 +25,11 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
lora_tokenizers
=
LRUCache
[
PreTrainedTokenizer
](
self
.
lora_tokenizers
=
LRUCache
[
PreTrainedTokenizer
](
capacity
=
max_num_seqs
)
if
enable_lora
else
None
capacity
=
max_num_seqs
)
if
enable_lora
else
None
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
                **init_kwargs) -> "TokenizerGroup":
    """Build a ``TokenizerGroup`` directly from ``init_kwargs``.

    ``tokenizer_pool_config`` is accepted but not used by this method; only
    ``init_kwargs`` are forwarded to the constructor.
    """
    return cls(**init_kwargs)
def
ping
(
self
)
->
bool
:
def
ping
(
self
)
->
bool
:
"""Check if the tokenizer group is alive."""
"""Check if the tokenizer group is alive."""
return
True
return
True
...
...
vllm/usage/usage_lib.py
View file @
500b93c8
...
@@ -16,12 +16,12 @@ import requests
...
@@ -16,12 +16,12 @@ import requests
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.connections
import
global_http_connection
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
_config_home
=
envs
.
VLLM_CONFIG_ROOT
_config_home
=
envs
.
VLLM_CONFIG_ROOT
_USAGE_STATS_JSON_PATH
=
os
.
path
.
join
(
_config_home
,
"vllm/usage_stats.json"
)
_USAGE_STATS_JSON_PATH
=
os
.
path
.
join
(
_config_home
,
"usage_stats.json"
)
_USAGE_STATS_DO_NOT_TRACK_PATH
=
os
.
path
.
join
(
_config_home
,
_USAGE_STATS_DO_NOT_TRACK_PATH
=
os
.
path
.
join
(
_config_home
,
"do_not_track"
)
"vllm/do_not_track"
)
_USAGE_STATS_ENABLED
=
None
_USAGE_STATS_ENABLED
=
None
_USAGE_STATS_SERVER
=
envs
.
VLLM_USAGE_STATS_SERVER
_USAGE_STATS_SERVER
=
envs
.
VLLM_USAGE_STATS_SERVER
...
@@ -205,7 +205,8 @@ class UsageMessage:
...
@@ -205,7 +205,8 @@ class UsageMessage:
def
_send_to_server
(
self
,
data
):
def
_send_to_server
(
self
,
data
):
try
:
try
:
requests
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
global_http_client
=
global_http_connection
.
get_sync_client
()
global_http_client
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
except
requests
.
exceptions
.
RequestException
:
except
requests
.
exceptions
.
RequestException
:
# silently ignore unless we are using debug log
# silently ignore unless we are using debug log
logging
.
debug
(
"Failed to send usage data to server"
)
logging
.
debug
(
"Failed to send usage data to server"
)
...
...
vllm/utils.py
View file @
500b93c8
...
@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
...
@@ -20,6 +20,7 @@ from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Union
)
Union
)
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
psutil
import
psutil
import
torch
import
torch
import
torch.types
import
torch.types
...
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -40,6 +41,15 @@ STR_DTYPE_TO_TORCH_DTYPE = {
# "fp8_e5m2": torch.uint8,
# "fp8_e5m2": torch.uint8,
}
}
# Lookup table from torch dtypes to their NumPy counterparts.
# `make_tensor_with_pad` indexes this table to choose the dtype of the
# intermediate NumPy array, so a torch dtype missing here raises KeyError there.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}
P
=
ParamSpec
(
'P'
)
P
=
ParamSpec
(
'P'
)
K
=
TypeVar
(
"K"
)
K
=
TypeVar
(
"K"
)
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
...
@@ -415,9 +425,10 @@ def init_kmp_env():
...
@@ -415,9 +425,10 @@ def init_kmp_env():
os
.
environ
[
'KMP_REDUCTION_BARRIER_PATTERN'
]
=
"dist,dist"
os
.
environ
[
'KMP_REDUCTION_BARRIER_PATTERN'
]
=
"dist,dist"
def
chunk_list
(
lst
:
List
[
T
],
chunk_size
:
int
)
->
List
[
List
[
T
]]
:
def
chunk_list
(
lst
:
List
[
T
],
chunk_size
:
int
):
"""Yield successive chunk_size chunks from lst."""
"""Yield successive chunk_size chunks from lst."""
return
[
lst
[
i
:
i
+
chunk_size
]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
)]
for
i
in
range
(
0
,
len
(
lst
),
chunk_size
):
yield
lst
[
i
:
i
+
chunk_size
]
def
cdiv
(
a
:
int
,
b
:
int
)
->
int
:
def
cdiv
(
a
:
int
,
b
:
int
)
->
int
:
...
@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
...
@@ -616,23 +627,54 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]:
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
f
"(e.g., 1, 2, 3). Given input:
{
s
}
"
)
from
e
def
make_tensor_with_pad
(
def
make_ndarray_with_pad
(
x
:
List
[
List
[
int
]],
x
:
List
[
List
[
T
]],
max_len
:
int
,
pad
:
T
,
pad
:
int
,
dtype
:
npt
.
DTypeLike
,
dtype
:
torch
.
dtype
,
*
,
device
:
Optional
[
Union
[
str
,
torch
.
device
]],
max_len
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
)
->
npt
.
NDArray
:
"""Make a padded tensor of a 2D inputs.
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
The padding is applied to the end of each inner list until it reaches
`max_len`.
`max_len`.
"""
"""
padded_x
=
np
.
zeros
([
len
(
x
),
max_len
],
dtype
=
np
.
int32
)
+
pad
if
max_len
is
None
:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len
=
max
(
map
(
len
,
x
),
default
=
0
)
padded_x
=
np
.
full
((
len
(
x
),
max_len
),
pad
,
dtype
=
dtype
)
for
ind
,
blocktb
in
enumerate
(
x
):
for
ind
,
blocktb
in
enumerate
(
x
):
assert
len
(
blocktb
)
<=
max_len
assert
len
(
blocktb
)
<=
max_len
padded_x
[
ind
,
:
len
(
blocktb
)]
=
blocktb
padded_x
[
ind
,
:
len
(
blocktb
)]
=
blocktb
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
device
)
return
padded_x
def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Build a padded 2D tensor from a list of lists.

    Each inner list is padded at its end with `pad` until it reaches
    `max_len`. The padding itself is done in NumPy by
    `make_ndarray_with_pad`; the resulting array is then wrapped as a tensor,
    moved to `device`, and pinned if `pin_memory` is set.

    Raises:
        KeyError: if `dtype` has no entry in `TORCH_DTYPE_TO_NUMPY_DTYPE`.
    """
    padded_array = make_ndarray_with_pad(
        x,
        pad,
        TORCH_DTYPE_TO_NUMPY_DTYPE[dtype],
        max_len=max_len,
    )
    result = torch.from_numpy(padded_array).to(device)
    return result.pin_memory() if pin_memory else result
def
async_tensor_h2d
(
def
async_tensor_h2d
(
...
@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
...
@@ -677,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return
dict
(
merged_dict
)
return
dict
(
merged_dict
)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the inner lists of ``lists`` into one new flat list.

    Element order is preserved: all items of the first inner list, then all
    items of the second, and so on.
    """
    flattened = []
    for inner in lists:
        flattened.extend(inner)
    return flattened
def
init_cached_hf_modules
()
->
None
:
def
init_cached_hf_modules
()
->
None
:
"""
"""
Lazy initialization of the Hugging Face modules.
Lazy initialization of the Hugging Face modules.
...
@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
...
@@ -939,3 +986,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
processed_args
.
append
(
arg
)
processed_args
.
append
(
arg
)
return
super
().
parse_args
(
processed_args
,
namespace
)
return
super
().
parse_args
(
processed_args
,
namespace
)
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Await ``task(*args, **kwargs)`` while holding ``lock``.

    The lock is released when the task finishes, whether it returns or
    raises. The task's result is returned.
    """
    async with lock:
        outcome = await task(*args, **kwargs)
    return outcome
vllm/version.py
View file @
500b93c8
...
@@ -9,4 +9,4 @@ except Exception as e:
...
@@ -9,4 +9,4 @@ except Exception as e:
stacklevel
=
2
)
stacklevel
=
2
)
__commit__
=
"COMMIT_HASH_PLACEHOLDER"
__commit__
=
"COMMIT_HASH_PLACEHOLDER"
__version__
=
"0.5.
2
"
__version__
=
"0.5.
3.post1
"
vllm/worker/cpu_model_runner.py
View file @
500b93c8
...
@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
...
@@ -276,11 +276,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
device
=
self
.
device
)
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
=
make_tensor_with_pad
(
block_tables
,
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
pad
=
0
,
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
device
=
self
.
device
,
...
...
vllm/worker/model_runner.py
View file @
500b93c8
...
@@ -2,7 +2,8 @@ import dataclasses
...
@@ -2,7 +2,8 @@ import dataclasses
import
gc
import
gc
import
time
import
time
import
warnings
import
warnings
from
collections
import
defaultdict
import
weakref
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
Tuple
,
Type
,
TypeVar
,
Union
)
...
@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
...
@@ -38,6 +39,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
supports_vision
)
supports_vision
)
from
vllm.model_executor.models.utils
import
set_cpu_offload_max_bytes
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
...
@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import (
...
@@ -47,10 +49,11 @@ from vllm.prompt_adapter.worker_manager import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
flatten_2d_lists
,
is_pin_memory_available
,
make_tensor_with_pad
)
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
...
@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
...
@@ -74,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPU
(
ModelRunnerInputBase
):
class
ModelInputForGPU
(
ModelRunnerInputBase
):
"""
"""
This base class contains metadata needed for the base model forward pass
This base class contains metadata needed for the base model forward pass
...
@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -124,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return
cls
(
**
tensor_dict
)
return
cls
(
**
tensor_dict
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPUWithSamplingMetadata
(
ModelInputForGPU
):
class
ModelInputForGPUWithSamplingMetadata
(
ModelInputForGPU
):
"""
"""
Used by the ModelRunner.
Used by the ModelRunner.
...
@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
...
@@ -165,6 +168,425 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return
cls
(
**
tensor_dict
)
return
cls
(
**
tensor_dict
)
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@dataclass
class InterDataForSeqGroup:
    """Per-sequence-group scratch data gathered while building model input.

    The per-sequence list fields are (re)initialised in ``__post_init__`` to
    hold one slot per entry of ``seq_ids``.
    """

    # Taken from the sequence group metadata.
    request_id: str
    seq_ids: List[int]
    is_prompt: bool
    block_tables: Optional[Dict[int, List[int]]]
    computed_block_nums: List[int]
    n_seqs: int = 0

    # Input tokens and their positions, one inner list per sequence.
    input_tokens: List[List[int]] = field(default_factory=list)
    input_positions: List[List[int]] = field(default_factory=list)

    # Sequence length (may be capped to the sliding window).
    seq_lens: List[int] = field(default_factory=list)
    # Original sequence length (before applying the sliding window);
    # used to compute the slot mapping.
    orig_seq_lens: List[int] = field(default_factory=list)
    # Query length.
    query_lens: List[int] = field(default_factory=list)
    # Number of tokens that are already computed.
    context_lens: List[int] = field(default_factory=list)
    # Current sliding window block.
    curr_sliding_window_blocks: List[int] = field(default_factory=list)

    # LoRA inputs.
    lora_index_mapping: List[List[int]] = field(default_factory=list)
    lora_prompt_mapping: List[List[int]] = field(default_factory=list)
    lora_requests: Set[LoRARequest] = field(default_factory=set)

    # Prompt adapter inputs.
    prompt_adapter_index_mapping: List[int] = field(default_factory=list)
    prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Multi-modal inputs.
    multi_modal_inputs: Optional[MultiModalInputs] = None

    # Whether the prefix cache is hit (prefill only).
    prefix_cache_hit: bool = False

    def __post_init__(self):
        """Size every per-sequence field to ``len(seq_ids)``."""
        count = len(self.seq_ids)
        self.n_seqs = count

        def fresh_lists():
            # Distinct inner lists, so sequences never share storage.
            return [[] for _ in range(count)]

        def zeros():
            return [0] * count

        self.input_tokens = fresh_lists()
        self.input_positions = fresh_lists()
        self.seq_lens = zeros()
        self.orig_seq_lens = zeros()
        self.query_lens = zeros()
        self.context_lens = zeros()
        self.curr_sliding_window_blocks = zeros()

        self.lora_index_mapping = fresh_lists()
        self.lora_prompt_mapping = fresh_lists()
def __init__(self,
             runner: "GPUModelRunnerBase",
             finished_requests_ids: Optional[List[str]] = None):
    """Set up a builder bound to ``runner``.

    Copies the configuration the builder needs from ``runner`` and creates
    the attention metadata builder from the runner's attention backend.

    Args:
        runner: The model runner whose configuration is read.
        finished_requests_ids: Stored as-is on the builder; may be ``None``.
    """
    super().__init__()
    # Compute functions for each sequence in a sequence group.
    # WARNING: The order of the functions matters!
    self.per_seq_compute_fns = [
        self._compute_lens,
        self._compute_for_prefix_cache_hit,
        self._compute_for_sliding_window,
        self._compute_lora_input,
    ]
    # Compute functions for each sequence group.
    # WARNING: The order of the functions matters!
    self.per_seq_group_compute_fns = [
        self._compute_prompt_adapter_input,
        self._compute_multi_modal_input,
    ]

    self.runner = runner
    self.model_input_cls = self.runner._model_input_cls
    self.attn_backend = self.runner.attn_backend
    self.scheduler_config = self.runner.scheduler_config
    self.sliding_window = self.runner.sliding_window
    self.block_size = self.runner.block_size
    # LoRA / prompt adapters count as enabled iff the runner has a config.
    self.enable_lora = self.runner.lora_config is not None
    self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                  is not None)
    self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
    self.finished_requests_ids = finished_requests_ids
    self.decode_only = True

    # Intermediate data (data in CPU before going to GPU) for
    # the current sequence group.
    self.inter_data_list: List[
        ModelInputForGPUBuilder.InterDataForSeqGroup] = []

    # Attention metadata inputs.
    # The metadata builder gets a weak proxy to this object
    # (presumably to avoid a reference cycle -- TODO confirm).
    self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
        weakref.proxy(self))

    # Engine/Model configurations.
    self.chunked_prefill_enabled = (
        self.scheduler_config is not None
        and self.scheduler_config.chunked_prefill_enabled)
    # NOTE: the two attributes below exist only when a sliding window is set.
    if self.sliding_window is not None:
        # Ceiling division: blocks needed to cover the sliding window.
        self.sliding_window_blocks = (
            self.sliding_window + self.block_size - 1) // self.block_size
        # Sliding window size rounded up to a whole number of blocks.
        self.block_aligned_sliding_window = \
            self.sliding_window_blocks * self.block_size
def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                  seq_group_metadata: SequenceGroupMetadata):
    """Fill in lengths, tokens and positions for one sequence.

    Writes ``seq_lens``, ``orig_seq_lens``, ``context_lens``,
    ``input_tokens``, ``input_positions`` and ``query_lens`` of
    ``inter_data`` at index ``seq_idx``.
    """
    seq_id = inter_data.seq_ids[seq_idx]
    seq_data = seq_group_metadata.seq_data[seq_id]
    chunk_size = seq_group_metadata.token_chunk_size
    total_len = seq_data.get_len()

    if inter_data.is_prompt:
        # Context = tokens already computed; the rest of this chunk is
        # the query.
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(total_len, context_len + chunk_size)
        tokens = seq_data.get_token_ids()[context_len:seq_len]
        query_len = seq_len - context_len
    else:
        # get_num_computed_tokens is incorrect for spec decoding, so the
        # context length is derived from the total length instead.
        # TODO(sang): Fix it.
        context_len = total_len - 1
        seq_len = min(total_len, context_len + chunk_size)
        # Optimization: get_token_ids would copy the whole token list.
        tokens = [seq_data.get_last_token_id()]
        query_len = 1

    inter_data.seq_lens[seq_idx] = seq_len
    inter_data.orig_seq_lens[seq_idx] = seq_len
    inter_data.context_lens[seq_idx] = context_len
    inter_data.input_tokens[seq_idx] = tokens
    inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
    inter_data.query_lens[seq_idx] = query_len
def _compute_for_prefix_cache_hit(
        self, inter_data: InterDataForSeqGroup, seq_idx: int,
        seq_group_metadata: SequenceGroupMetadata):
    """Skip already-computed blocks when the prefix cache was hit.

    Records the hit flag on ``inter_data.prefix_cache_hit``. On a hit the
    input tokens and positions of ``seq_idx`` are trimmed to what follows
    the cached blocks, and the context and query lengths are updated.

    Raises:
        RuntimeError: on a hit while chunked prefill is enabled.
    """
    computed_block_nums = inter_data.computed_block_nums

    # A hit needs at least one cached block on a prompt, and no sliding
    # window (prefix caching does not support sliding window).
    has_cached_blocks = (computed_block_nums is not None
                         and len(computed_block_nums) > 0)
    hit = (has_cached_blocks and self.sliding_window is None
           and inter_data.is_prompt)
    inter_data.prefix_cache_hit = hit

    if self.chunked_prefill_enabled and hit:
        raise RuntimeError(
            "chunked prefill cannot be used with prefix caching now.")

    if not hit:
        return

    assert computed_block_nums is not None
    # Everything covered by the cached blocks becomes context.
    context_len = len(computed_block_nums) * self.block_size
    remainder = slice(context_len, None)

    inter_data.input_tokens[seq_idx] = (
        inter_data.input_tokens[seq_idx][remainder])
    inter_data.input_positions[seq_idx] = (
        inter_data.input_positions[seq_idx][remainder])
    inter_data.context_lens[seq_idx] = context_len
    inter_data.query_lens[seq_idx] = (inter_data.seq_lens[seq_idx] -
                                      context_len)
def _compute_for_sliding_window(
        self, inter_data: InterDataForSeqGroup, seq_idx: int,
        seq_group_metadata: SequenceGroupMetadata):
    """Cap the sequence length to the sliding window for decoding.

    Writes ``curr_sliding_window_blocks[seq_idx]`` and overwrites
    ``seq_lens[seq_idx]`` on ``inter_data``. For prompts, or when no
    sliding window is configured, the block count is 0 and the sequence
    length is left as it was.
    """
    full_len = inter_data.seq_lens[seq_idx]
    window_blocks = 0
    capped_len = full_len

    window_applies = not (inter_data.is_prompt
                          or self.sliding_window is None)
    if window_applies:
        # TODO(sang): This is a hack to make sliding window work with
        # paged attn. We can remove it if we make the paged attn kernel
        # properly handle sliding window attn.
        window_blocks = self.sliding_window_blocks
        if not self.scheduler_config.use_v2_block_manager:
            capped_len = min(inter_data.seq_lens[seq_idx],
                             self.sliding_window)
        else:
            # Number of elements in the last block.
            tail_len = inter_data.seq_lens[seq_idx] % self.block_size
            capped_len = min(inter_data.seq_lens[seq_idx],
                             self.block_aligned_sliding_window + tail_len)
            if tail_len > 0:
                window_blocks += 1

    inter_data.curr_sliding_window_blocks[seq_idx] = window_blocks
    inter_data.seq_lens[seq_idx] = capped_len
def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                        seq_idx: int,
                        seq_group_metadata: SequenceGroupMetadata):
    """Append the LoRA index and prompt mappings for one sequence.

    Does nothing unless LoRA is enabled. A positive LoRA id also adds the
    group's LoRA request to ``inter_data.lora_requests``.
    """
    if not self.enable_lora:
        return

    lora_id = seq_group_metadata.lora_int_id
    if lora_id > 0:
        inter_data.lora_requests.add(seq_group_metadata.lora_request)

    query_len = inter_data.query_lens[seq_idx]
    # One index entry per query token.
    inter_data.lora_index_mapping.append([lora_id] * query_len)

    # The prompt mapping covers every query token only when prompt
    # logprobs are requested; otherwise it has a single entry.
    sampling_params = seq_group_metadata.sampling_params
    if sampling_params and sampling_params.prompt_logprobs is not None:
        prompt_entries = query_len
    else:
        prompt_entries = 1
    inter_data.lora_prompt_mapping.append([lora_id] * prompt_entries)
def _compute_prompt_adapter_input(
        self, inter_data: InterDataForSeqGroup,
        seq_group_metadata: SequenceGroupMetadata):
    """Set the prompt-adapter request and mappings for a prompt group.

    Does nothing unless prompt adapters are enabled, the group's adapter
    id is positive and the group is a prompt.
    """
    if not self.enable_prompt_adapter:
        return

    adapter_id = seq_group_metadata.prompt_adapter_id
    if not (adapter_id > 0 and inter_data.is_prompt):
        return

    # A prompt group is expected to hold exactly one sequence.
    assert inter_data.n_seqs == 1
    query_len = inter_data.query_lens[0]
    inter_data.prompt_adapter_request = (
        seq_group_metadata.prompt_adapter_request)

    # The first `virtual_tokens` entries carry the adapter id, the rest 0.
    virtual_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
    adapter_part = [adapter_id] * virtual_tokens
    plain_part = [0] * (query_len - virtual_tokens)
    inter_data.prompt_adapter_index_mapping = adapter_part + plain_part

    # NOTE(review): prompt_logprobs is tested for truthiness here, while
    # _compute_lora_input tests `is not None`; confirm this is intended.
    sampling_params = seq_group_metadata.sampling_params
    if sampling_params and sampling_params.prompt_logprobs:
        prompt_entries = query_len
    else:
        prompt_entries = 1
    inter_data.prompt_adapter_prompt_mapping = [adapter_id] * prompt_entries
def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                               seq_group_metadata: SequenceGroupMetadata):
    """Map the group's multi-modal data, if any, onto ``inter_data``.

    When ``seq_group_metadata.multi_modal_data`` is falsy nothing is
    written; otherwise the mapper's result is stored in
    ``inter_data.multi_modal_inputs``.
    """
    raw_mm_data = seq_group_metadata.multi_modal_data
    if raw_mm_data:
        inter_data.multi_modal_inputs = self.multi_modal_input_mapper(
            raw_mm_data)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
    """Add a sequence group to the builder.

    Creates the group's intermediate record, appends it to
    ``self.inter_data_list`` and runs every per-sequence hook for each
    sequence, followed by the per-group hooks. A prompt group must hold
    exactly one sequence and clears ``self.decode_only``.
    """
    seq_ids = list(seq_group_metadata.seq_data.keys())
    num_seqs = len(seq_ids)
    is_prompt = seq_group_metadata.is_prompt

    if is_prompt:
        assert num_seqs == 1
        self.decode_only = False

    record = self.InterDataForSeqGroup(
        request_id=seq_group_metadata.request_id,
        seq_ids=seq_ids,
        is_prompt=is_prompt,
        block_tables=seq_group_metadata.block_tables,
        computed_block_nums=seq_group_metadata.computed_block_nums)
    self.inter_data_list.append(record)

    # Per-sequence hooks: all hooks for sequence 0, then sequence 1, ...
    for idx, _ in enumerate(seq_ids):
        for compute in self.per_seq_compute_fns:
            compute(record, idx, seq_group_metadata)

    # Per-group hooks run once, after all per-sequence hooks.
    for compute_group in self.per_seq_group_compute_fns:
        compute_group(record, seq_group_metadata)
def build(self) -> ModelInputForGPU:
    """Finalize the builder intermediate data and create on-device
    tensors.

    Returns:
        A ``self.model_input_cls`` instance. It is constructed with no
        arguments when the combined input-token list is empty.
    """
    groups = self.inter_data_list

    # Combine and flatten the per-group, per-sequence token lists.
    input_tokens = flatten_2d_lists(
        [flatten_2d_lists(group.input_tokens) for group in groups])
    if not input_tokens:
        # This may happen when all prefill requests hit
        # prefix caching and there is no decode request.
        return self.model_input_cls()

    input_positions = flatten_2d_lists(
        [flatten_2d_lists(group.input_positions) for group in groups])

    seq_lens = []
    max_decode_seq_len = 0
    for group in groups:
        seq_lens.extend(group.seq_lens)
        if group.is_prompt:
            continue
        # Only decode groups count towards the longest decode length.
        max_decode_seq_len = max(max_decode_seq_len, max(group.seq_lens))

    query_lens = flatten_2d_lists([group.query_lens for group in groups])

    # Mapping from request IDs to sequence IDs. Used for Jamba models
    # that manage the cache by themselves.
    request_ids_to_seq_ids = {
        group.request_id: group.seq_ids
        for group in groups
    }

    batch_size = len(input_tokens)
    use_captured_graph = (
        self.decode_only and not self.runner.model_config.enforce_eager
        and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
        and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

    # If cuda graph can be used, pad tensors accordingly.
    # See `capture_model` API for more details.
    # vLLM uses cuda graph only for decoding requests.
    # The default of -1 turns every `[x] * cuda_graph_pad_size` below
    # into an empty list, so no padding is added without a graph.
    cuda_graph_pad_size = -1
    if use_captured_graph:
        graph_batch_size = _get_graph_batch_size(batch_size)
        assert graph_batch_size >= batch_size
        cuda_graph_pad_size = graph_batch_size - batch_size
        batch_size = graph_batch_size

    device = self.runner.device

    # Tokens and positions.
    input_tokens.extend([0] * cuda_graph_pad_size)
    input_positions.extend([0] * cuda_graph_pad_size)
    input_tokens_tensor = torch.tensor(input_tokens,
                                       dtype=torch.long,
                                       device=device)
    input_positions_tensor = torch.tensor(input_positions,
                                          dtype=torch.long,
                                          device=device)

    # Sequence lengths: padded entries get length 1.
    seq_lens.extend([1] * cuda_graph_pad_size)

    # Attention metadata.
    attn_metadata = self.attn_metadata_builder.build(
        seq_lens, query_lens, cuda_graph_pad_size, batch_size)

    # LoRA data.
    lora_requests = set()
    lora_mapping = None
    if self.enable_lora:
        lora_requests = {
            request
            for group in groups for request in group.lora_requests
        }
        lora_index_mapping = flatten_2d_lists([
            flatten_2d_lists(group.lora_index_mapping) for group in groups
        ])
        lora_index_mapping.extend([0] * cuda_graph_pad_size)
        lora_prompt_mapping = flatten_2d_lists([
            flatten_2d_lists(group.lora_prompt_mapping) for group in groups
        ])
        lora_mapping = LoRAMapping(
            lora_index_mapping,
            lora_prompt_mapping,
        )

    # Prompt adapter data.
    prompt_adapter_requests: Set[PromptAdapterRequest] = set()
    prompt_adapter_mapping = None
    if self.enable_prompt_adapter:
        prompt_adapter_requests = {
            group.prompt_adapter_request
            for group in groups
            if group.prompt_adapter_request is not None
        }
        prompt_adapter_index_mapping = flatten_2d_lists(
            [group.prompt_adapter_index_mapping for group in groups])
        prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
        prompt_adapter_prompt_mapping = flatten_2d_lists(
            [group.prompt_adapter_prompt_mapping for group in groups])
        prompt_adapter_mapping = PromptAdapterMapping(
            prompt_adapter_index_mapping,
            prompt_adapter_prompt_mapping,
        )

    # Multi-modal data: only groups that actually carry inputs.
    multi_modal_inputs_list = [
        group.multi_modal_inputs for group in groups
        if group.multi_modal_inputs is not None
    ]
    multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                device=device)

    return self.model_input_cls(
        input_tokens=input_tokens_tensor,
        input_positions=input_positions_tensor,
        attn_metadata=attn_metadata,
        seq_lens=seq_lens,
        query_lens=query_lens,
        lora_mapping=lora_mapping,
        lora_requests=lora_requests,
        multi_modal_kwargs=multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids,
        finished_requests_ids=self.finished_requests_ids,
        prompt_adapter_mapping=prompt_adapter_mapping,
        prompt_adapter_requests=prompt_adapter_requests)
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
"""
"""
Helper class for shared methods between GPU model runners.
Helper class for shared methods between GPU model runners.
...
@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -251,7 +673,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
self
.
flashinfer_prefill_wrapper
=
None
set_cpu_offload_max_bytes
(
int
(
self
.
cache_config
.
cpu_offload_gb
*
1024
**
3
))
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
CudaMemoryProfiler
()
as
m
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
device_config
=
self
.
device_config
,
device_config
=
self
.
device_config
,
...
@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -368,464 +794,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
If cuda graph is required, this API automatically pads inputs.
"""
"""
input_tokens
:
List
[
int
]
=
[]
builder
=
ModelInputForGPUBuilder
(
weakref
.
proxy
(
self
),
input_positions
:
List
[
int
]
=
[]
finished_requests_ids
)
slot_mapping
:
List
[
int
]
=
[]
lora_index_mapping
:
List
[
int
]
=
[]
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
prompt_adapter_index_mapping
:
List
[
int
]
=
[]
prompt_adapter_prompt_mapping
:
List
[
int
]
=
[]
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
seq_lens
:
List
[
int
]
=
[]
prefill_seq_lens
:
List
[
int
]
=
[]
decode_seq_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
query_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
request_ids_to_seq_ids
:
Dict
[
str
,
List
[
int
]]
=
defaultdict
(
list
)
decode_only
=
True
num_prefills
=
0
num_prefill_tokens
=
0
num_decode_tokens
=
0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len
:
List
[
int
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_model_input_cls
()
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
block_aligned_sliding_window
=
\
sliding_window_blocks
*
self
.
block_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
builder
.
add_seq_group
(
seq_group_metadata
)
is_prompt
=
seq_group_metadata
.
is_prompt
return
builder
.
build
()
# type: ignore
for
seq_id
in
seq_ids
:
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
if
(
self
.
scheduler_config
is
not
None
and
self
.
scheduler_config
.
chunked_prefill_enabled
and
not
(
computed_block_nums
is
None
or
computed_block_nums
==
[])):
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching "
"now."
)
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_data
.
get_len
()
-
1
seq_len
=
min
(
seq_data
.
get_len
(),
context_len
+
seq_group_metadata
.
token_chunk_size
)
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks
=
None
sliding_seq_len
=
seq_len
sliding_context_len
=
context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if
(
self
.
sliding_window
is
not
None
and
not
is_prompt
):
curr_sliding_window_blocks
=
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
sliding_seq_len
=
min
(
seq_len
,
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_blocks
+=
1
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
sliding_context_len
=
sliding_seq_len
-
1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert
self
.
sliding_window
is
None
,
\
"Prefix caching is not supported with sliding window"
sliding_context_len
=
context_len
if
self
.
attn_backend
.
get_name
()
==
"flash-attn"
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
else
:
block_table
=
computed_block_nums
elif
(
self
.
scheduler_config
.
chunked_prefill_enabled
or
not
is_prompt
):
if
seq_group_metadata
.
block_tables
is
not
None
:
# chunked prefill or decode
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
if
curr_sliding_window_blocks
is
not
None
:
block_table
=
block_table
[
-
curr_sliding_window_blocks
:]
else
:
# Only happens when memory profiling runs.
block_table
=
[]
else
:
# Prefill without chunked prefill or memory profiling.
block_table
=
[]
block_tables
.
append
(
block_table
)
seq_lens
.
append
(
sliding_seq_len
)
context_lens
.
append
(
sliding_context_len
)
query_len
=
sliding_seq_len
-
sliding_context_len
query_lens
.
append
(
query_len
)
input_tokens
.
extend
(
tokens
)
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
lora_id
=
seq_group_metadata
.
lora_int_id
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
num_prefills
+=
1
num_prefill_tokens
+=
len
(
tokens
)
decode_only
=
False
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
num_decode_tokens
+=
query_len
decode_seq_lens
.
append
(
sliding_seq_len
)
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
query_len
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
# Process multi-modal data
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
if
prompt_adapter_id
>
0
and
is_prompt
:
prompt_adapter_requests
.
add
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
\
prompt_adapter_num_virtual_tokens
pm
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
prompt_adapter_index_mapping
+=
pm
prompt_adapter_prompt_mapping
.
extend
(
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is_profile_run
=
_is_block_tables_empty
(
seq_group_metadata
.
block_tables
)
if
is_profile_run
:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
is_prompt
:
assert
self
.
scheduler_config
.
use_v2_block_manager
\
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager"
)
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx
=
max
(
0
,
query_len
-
self
.
sliding_window
)
for
i
in
range
(
context_len
,
seq_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
# Prepare input tensors for flashinfer
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
seq_len
=
seq_data
.
get_len
()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
paged_kv_last_page_len
.
append
(
last_page_len
)
batch_size
=
len
(
input_tokens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
decode_seq_lens
,
default
=
0
)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph
=
(
decode_only
and
not
self
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
max_decode_seq_len
<=
self
.
max_seq_len_to_capture
)
if
use_captured_graph
:
graph_batch_size
=
_get_graph_batch_size
(
batch_size
)
assert
graph_batch_size
>=
batch_size
for
_
in
range
(
graph_batch_size
-
batch_size
):
input_tokens
.
append
(
0
)
input_positions
.
append
(
0
)
slot_mapping
.
append
(
_PAD_SLOT_ID
)
seq_lens
.
append
(
1
)
block_tables
.
append
([])
lora_index_mapping
.
append
(
0
)
prompt_adapter_index_mapping
.
append
(
0
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
last_paged_kv_indptr
=
paged_kv_indptr
[
-
1
]
paged_kv_indptr
.
append
(
last_paged_kv_indptr
)
paged_kv_last_page_len
.
append
(
0
)
batch_size
=
graph_batch_size
num_decode_tokens
=
batch_size
if
use_captured_graph
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
self
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
tensor
(
input_block_tables
,
device
=
self
.
device
)
else
:
max_block_table_len
=
max
(
len
(
block_table
)
for
block_table
in
block_tables
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions_tensor
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping_tensor
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
logits_soft_cap
=
getattr
(
self
.
model_config
.
hf_config
,
'attn_logit_softcapping'
,
None
)
if
logits_soft_cap
is
not
None
and
self
.
attn_backend
.
get_name
(
)
!=
"flashinfer"
:
raise
ValueError
(
"Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
len
(
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
paged_kv_indices
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
paged_kv_indptr
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
paged_kv_last_page_len_tensor
=
torch
.
tensor
(
paged_kv_last_page_len
,
device
=
'cpu'
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
self
.
device
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
logits_soft_cap
=
logits_soft_cap
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
else
:
lora_mapping
=
None
if
self
.
prompt_adapter_config
:
prompt_adapter_mapping
=
PromptAdapterMapping
(
prompt_adapter_index_mapping
,
prompt_adapter_prompt_mapping
,
)
else
:
prompt_adapter_mapping
=
None
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
request_ids_to_seq_ids
=
{
seq_group_metadata
.
request_id
:
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_group_metadata
in
seq_group_metadata_list
}
return
self
.
_model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
attn_metadata
=
attn_metadata
,
seq_lens
=
seq_lens
,
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_requests
=
lora_requests
,
multi_modal_kwargs
=
multi_modal_kwargs
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids
,
finished_requests_ids
=
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
prompt_adapter_requests
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
...
@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -847,7 +820,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dummy_lora_request
=
LoRARequest
(
dummy_lora_request
=
LoRARequest
(
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_name
=
f
"warmup_
{
lora_id
}
"
,
lora_int_id
=
lora_id
,
lora_int_id
=
lora_id
,
lora_
local_
path
=
"/not/a/real/path"
,
lora_path
=
"/not/a/real/path"
,
)
)
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
self
.
lora_manager
.
add_dummy_lora
(
dummy_lora_request
,
rank
=
LORA_WARMUP_RANK
)
rank
=
LORA_WARMUP_RANK
)
...
@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
...
@@ -1549,15 +1522,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else
:
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """Tell whether ``block_tables`` carries no block table at all.

    Returns True for ``None`` and for a dict whose values are all ``None``
    (an empty dict included); returns False for anything else.
    """
    if block_tables is None:
        return True
    return isinstance(block_tables, dict) and all(
        table is None for table in block_tables.values())
vllm/worker/model_runner_base.py
View file @
500b93c8
...
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
...
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
...
@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC):
...
@@ -113,6 +114,21 @@ class ModelRunnerInputBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.

    Abstract interface, generic over the type ``T`` returned by ``build``.
    Both methods are abstract and raise ``NotImplementedError`` here, so
    concrete builders must override them.
    """

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the builder.

        Args:
            seq_group_metadata: Metadata of the sequence group to add.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors.

        Returns:
            The built object of type ``T``.
        """
        raise NotImplementedError
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
class
ModelRunnerBase
(
ABC
,
Generic
[
T
]):
"""
"""
Model runner interface that abstracts a particular hardware and/or type of
Model runner interface that abstracts a particular hardware and/or type of
...
@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -148,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
torch
.
inference_mode
()
@
current_platform
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
T
,
model_input
:
T
,
...
...
vllm/worker/neuron_model_runner.py
View file @
500b93c8
...
@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -121,13 +121,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
max_seq_len
=
max
(
seq_lens
)
max_seq_len
=
max
(
seq_lens
)
assert
max_seq_len
>
0
assert
max_seq_len
>
0
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_seq_len
,
pad
=
0
,
pad
=
0
,
max_len
=
max_seq_len
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_seq_len
,
pad
=
0
,
pad
=
0
,
max_len
=
max_seq_len
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
input_block_ids
=
torch
.
tensor
(
input_block_ids
,
input_block_ids
=
torch
.
tensor
(
input_block_ids
,
...
@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -171,13 +171,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids
.
append
(
block_table
[
0
])
input_block_ids
.
append
(
block_table
[
0
])
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_len
=
1
,
pad
=
0
,
pad
=
0
,
max_len
=
1
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_len
=
1
,
pad
=
0
,
pad
=
0
,
max_len
=
1
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
context_lens
=
torch
.
tensor
(
context_lens
,
context_lens
=
torch
.
tensor
(
context_lens
,
...
...
vllm/worker/tpu_model_runner.py
View file @
500b93c8
import
time
import
time
from
typing
import
List
,
Mapping
,
Optional
,
Tuple
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
...
@@ -12,12 +13,16 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensors
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
MultiModalInputs
)
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False
...
@@ -29,7 +34,44 @@ _ENABLE_TOP_P = False
_MAX_NUM_SAMPLES
=
128
_MAX_NUM_SAMPLES
=
128
class
TPUModelRunner
:
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Immutable bundle of everything TPUModelRunner.execute_model consumes.

    Instances are produced by TPUModelRunner.prepare_model_input and can be
    converted to / rebuilt from a flat dict for broadcasting.
    """
    # Input token IDs and their positions.
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    # Attention metadata created by the attention backend.
    attn_metadata: AttentionMetadata
    # Actual (unpadded) input length of each batch entry.
    input_lens: torch.Tensor
    # Sampling temperature and top-p probability for each batch entry.
    t: torch.Tensor
    p: torch.Tensor
    # Number of samples to draw from each logits vector.
    num_samples: int
    # best_of values as returned by TPUModelRunner._prepare_sample.
    best_of: List[int]
    # Sequence IDs of each sequence group, in scheduling order.
    seq_groups: List[List[int]]

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Flatten this input into a dict that can be broadcast.

        Every dataclass field is emitted under its own name, except
        attn_metadata, which _add_attn_metadata_broadcastable_dict merges
        into the dict.

        Returns:
            The dict accepted by from_broadcasted_tensor_dict().
        """
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            # best_of and seq_groups are required constructor arguments;
            # without them from_broadcasted_tensor_dict() cannot rebuild
            # the object via cls(**tensor_dict).
            "best_of": self.best_of,
            "seq_groups": self.seq_groups,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild a ModelInputForTPU from a broadcasted dict.

        Args:
            tensor_dict: Dict produced by as_broadcastable_tensor_dict().
            attn_backend: When given, the dict is first passed through
                _init_attn_metadata_from_tensor_dict with this backend.

        Returns:
            A new instance constructed with the dict entries as keyword
            arguments.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)
class
TPUModelRunner
(
ModelRunnerBase
[
ModelInputForTPU
]):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -68,10 +110,6 @@ class TPUModelRunner:
...
@@ -68,10 +110,6 @@ class TPUModelRunner:
False
,
False
,
)
)
# Multi-modal data support
self
.
multi_modal_input_mapper
=
MULTIMODAL_REGISTRY
\
.
create_input_mapper
(
self
.
model_config
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
...
@@ -85,6 +123,7 @@ class TPUModelRunner:
...
@@ -85,6 +123,7 @@ class TPUModelRunner:
multimodal_config
=
self
.
multimodal_config
,
multimodal_config
=
self
.
multimodal_config
,
lora_config
=
None
,
lora_config
=
None
,
)
)
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
model
=
ModelWrapper
(
model
)
...
@@ -153,8 +192,8 @@ class TPUModelRunner:
...
@@ -153,8 +192,8 @@ class TPUModelRunner:
# Dummy run.
# Dummy run.
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
self
.
model
(
token_ids
,
position_ids
,
kv_caches
,
attn_metadata
,
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
input_
le
n
s
,
None
,
t
,
p
,
num_sampl
es
)
num_samp
les
,
kv_cach
es
)
def
warmup_model
(
def
warmup_model
(
self
,
self
,
...
@@ -183,7 +222,7 @@ class TPUModelRunner:
...
@@ -183,7 +222,7 @@ class TPUModelRunner:
# Decode
# Decode
start
=
time
.
time
()
start
=
time
.
time
()
seq_len
=
1
seq_len
=
1
batch_size
=
1
batch_size
=
8
# Must be in sync with _get_padded_batch_size()
while
True
:
while
True
:
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
self
.
_dummy_run
(
batch_size
,
seq_len
,
kv_caches
,
is_prompt
=
False
)
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
...
@@ -199,14 +238,12 @@ class TPUModelRunner:
...
@@ -199,14 +238,12 @@ class TPUModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
Mapping
[
str
,
BatchedTensors
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
List
[
int
]
]
=
[]
input_positions
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
prompt_lens
:
List
[
int
]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
is_prompt
...
@@ -220,78 +257,62 @@ class TPUModelRunner:
...
@@ -220,78 +257,62 @@ class TPUModelRunner:
prompt_len
=
len
(
prompt_tokens
)
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
app
end
(
prompt_tokens
)
input_tokens
.
ext
end
(
prompt_tokens
)
input_positions
.
app
end
(
list
(
range
(
prompt_len
)))
input_positions
.
ext
end
(
list
(
range
(
prompt_len
)))
assert
seq_group_metadata
.
block_tables
is
not
None
assert
seq_group_metadata
.
block_tables
is
not
None
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
slot_mapping
.
append
([])
for
i
in
range
(
prompt_len
):
for
i
in
range
(
prompt_len
):
block_number
=
block_table
[
i
//
self
.
block_size
]
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
slot_mapping
.
append
(
slot
)
mm_data
=
seq_group_metadata
.
multi_modal_data
# Add paddings to EACH prompt to the smallest power of 2 that is
if
mm_data
:
# greater than or equal to the prompt length.
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
# We pad the seq_len to reduce the compilation overhead.
multi_modal_inputs_list
.
append
(
mm_kwargs
)
# We execute each prompt individually (i.e., with batch_size 1)
# because the FlashAttention kernel does not support ragged inputs.
# TODO(woosuk): Use SplashAttention to support ragged inputs.
padded_prompt_len
=
_get_padded_prefill_len
(
prompt_len
)
num_paddings
=
padded_prompt_len
-
prompt_len
input_tokens
+=
[
0
]
*
num_paddings
input_positions
+=
[
0
]
*
num_paddings
slot_mapping
+=
[
_PAD_SLOT_ID
]
*
num_paddings
assert
len
(
prompt_lens
)
>
0
assert
len
(
prompt_lens
)
>
0
num_prefills
=
len
(
prompt_lens
)
num_prefills
=
len
(
prompt_lens
)
num_prefill_tokens
=
sum
(
prompt_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
# Add paddings to make the shape [batch_size, max_prompt_len] where
device
=
"cpu"
)
# max_prompt_len is smallest power of 2 that is greater than or equal
input_positions
=
torch
.
tensor
(
input_positions
,
# to the maximum prompt length.
dtype
=
torch
.
int32
,
# We need the 2D input shape because the Pallas FlashAttention kernel
device
=
"cpu"
)
# does not support packed 1D inputs.
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
# We pad the seq_len to powers of 2 to reduce the compilation overhead.
dtype
=
torch
.
int64
,
max_prompt_len
=
_get_padded_prefill_len
(
max
(
prompt_lens
))
device
=
"cpu"
)
input_tokens
=
make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
input_positions
=
make_tensor_with_pad
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
slot_mapping
=
make_tensor_with_pad
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
prompt_lens
=
torch
.
tensor
(
prompt_lens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
# NOTE: This is not used.
num_prefill_tokens
=
0
,
# NOTE: This is not used.
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
block_tables
=
None
,
block_tables
=
None
,
context_lens
=
None
,
context_lens
=
None
,
)
)
return
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
prompt_lens
,
multi_modal_kwargs
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
torch
.
Tensor
]:
Mapping
[
str
,
BatchedTensors
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
context_lens
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
batch_idx
=
0
batch_idx
=
0
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
...
@@ -317,11 +338,6 @@ class TPUModelRunner:
...
@@ -317,11 +338,6 @@ class TPUModelRunner:
slot
=
block_number
*
self
.
block_size
+
block_offset
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
slot_mapping
.
append
([
slot
])
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
batch_size
=
_get_padded_batch_size
(
batch_idx
)
batch_size
=
_get_padded_batch_size
(
batch_idx
)
num_paddings
=
batch_size
-
batch_idx
num_paddings
=
batch_size
-
batch_idx
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
input_tokens
=
input_tokens
+
[[
0
]]
*
num_paddings
...
@@ -331,22 +347,22 @@ class TPUModelRunner:
...
@@ -331,22 +347,22 @@ class TPUModelRunner:
input_tokens
=
torch
.
tensor
(
input_tokens
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
device
=
"cpu"
)
context_lens
=
torch
.
tensor
(
context_lens
,
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
block_tables
=
torch
.
tensor
(
self
.
block_tables
[:
batch_size
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
input_lens
=
torch
.
tensor
([
1
]
*
batch_size
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
device
=
"cpu"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
...
@@ -355,12 +371,7 @@ class TPUModelRunner:
...
@@ -355,12 +371,7 @@ class TPUModelRunner:
block_tables
=
block_tables
,
block_tables
=
block_tables
,
context_lens
=
context_lens
,
context_lens
=
context_lens
,
)
)
return
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
device
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
multi_modal_kwargs
)
def
_prepare_sample
(
def
_prepare_sample
(
self
,
self
,
...
@@ -412,16 +423,18 @@ class TPUModelRunner:
...
@@ -412,16 +423,18 @@ class TPUModelRunner:
t
+=
[
1.0
]
*
num_paddings
t
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
p
+=
[
1.0
]
*
num_paddings
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
t
,
p
,
best_of
return
t
,
p
,
best_of
def
_execut
e_model
(
def
prepar
e_model
_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
virtual_engine
:
int
=
0
,
)
->
List
[
CompletionSequenceGroupOutput
]:
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
# Prepare inputs.
)
->
ModelInputForTPU
:
del
finished_requests_ids
# Unused.
assert
virtual_engine
==
0
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# all decodes.
...
@@ -430,16 +443,104 @@ class TPUModelRunner:
...
@@ -430,16 +443,104 @@ class TPUModelRunner:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
padded_batch_size
=
inputs
[
0
].
shape
[
0
]
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
=
inputs
padded_batch_size
=
input_tokens
.
shape
[
0
]
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
# Execute the model.
seq_groups
=
[
next_token_ids
=
self
.
model
(
inputs
[
0
],
inputs
[
1
],
kv_caches
,
list
(
metadata
.
seq_data
.
keys
())
*
inputs
[
2
:],
t
,
p
,
num_samples
)
for
metadata
in
seq_group_metadata_list
# Retrieve the outputs to CPU.
]
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
return
ModelInputForTPU
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
best_of
,
seq_groups
)
def make_model_input_from_broadcasted_tensor_dict(
        self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
    """Rebuild a ModelInputForTPU from a broadcasted tensor dict.

    Delegates to ModelInputForTPU.from_broadcasted_tensor_dict, passing
    this runner's attention backend as attn_backend.
    """
    return ModelInputForTPU.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=self.attn_backend)
def
execute_model
(
self
,
model_input
:
ModelInputForTPU
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
assert
intermediate_tensors
is
None
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
def
_execute_model
(
*
args
,
clone
:
bool
=
False
)
->
torch
.
Tensor
:
"""Move input args from CPU to device and execute the model."""
def
_copy_to_device
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
clone
:
# When x is a slice of a CPU tensor, XLA may copy the whole
# original tensor to TPU instead of only copying x.
# To avoid this, we copy x after cloning.
x
=
x
.
clone
()
return
x
.
to
(
self
.
device
)
new_args
=
[]
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
arg
=
_copy_to_device
(
arg
)
elif
isinstance
(
arg
,
AttentionMetadata
):
arg
.
slot_mapping
=
_copy_to_device
(
arg
.
slot_mapping
)
if
getattr
(
arg
,
"block_tables"
,
None
)
is
not
None
:
arg
.
block_tables
=
_copy_to_device
(
arg
.
block_tables
)
if
getattr
(
arg
,
"context_lens"
,
None
)
is
not
None
:
arg
.
context_lens
=
_copy_to_device
(
arg
.
context_lens
)
new_args
.
append
(
arg
)
return
self
.
model
(
*
new_args
)
num_prefills
=
model_input
.
attn_metadata
.
num_prefills
is_prompt
=
num_prefills
>
0
if
is_prompt
:
# NOTE(woosuk): Since the FlashAttention kernel does not support
# ragged inputs, we split the prompts into different batches and
# process them separately. This is a temporary hack that should be
# optimized by using SplashAttention.
next_token_ids
=
[]
orig_slot_mapping
=
model_input
.
attn_metadata
.
slot_mapping
batch_size
=
model_input
.
input_lens
.
shape
[
0
]
start_idx
=
0
for
i
in
range
(
batch_size
):
# Get the actual prefill_len.
prefill_len
=
model_input
.
input_lens
[
i
:
i
+
1
].
item
()
prefill_len
=
_get_padded_prefill_len
(
prefill_len
)
end_idx
=
start_idx
+
prefill_len
model_input
.
attn_metadata
.
slot_mapping
=
orig_slot_mapping
[
None
,
start_idx
:
end_idx
]
model_input
.
attn_metadata
.
num_prefills
=
1
output_token_ids
=
_execute_model
(
model_input
.
token_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
position_ids
[
None
,
start_idx
:
end_idx
],
model_input
.
attn_metadata
,
model_input
.
input_lens
[
i
:
i
+
1
],
model_input
.
t
[
i
:
i
+
1
],
model_input
.
p
[
i
:
i
+
1
],
model_input
.
num_samples
,
kv_caches
,
clone
=
True
)
# Retrieve the outputs to CPU.
next_token_ids
+=
output_token_ids
.
cpu
().
tolist
()
start_idx
=
end_idx
else
:
# Execute the model.
output_token_ids
=
_execute_model
(
model_input
.
token_ids
,
model_input
.
position_ids
,
model_input
.
attn_metadata
,
model_input
.
input_lens
,
model_input
.
t
,
model_input
.
p
,
model_input
.
num_samples
,
kv_caches
)
# Retrieve the outputs to CPU.
next_token_ids
=
output_token_ids
.
cpu
().
tolist
()
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# The TPU backend does not reuse the sampler, since the TPU backend
...
@@ -447,13 +548,13 @@ class TPUModelRunner:
...
@@ -447,13 +548,13 @@ class TPUModelRunner:
zero_logprob
=
Logprob
(
0.0
)
zero_logprob
=
Logprob
(
0.0
)
batch_idx
=
0
batch_idx
=
0
sampler_outputs
=
[]
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group
in
model_input
.
seq_groups
:
seq_ids
=
seq_group
seq_outputs
=
[]
seq_outputs
=
[]
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
if
is_prompt
:
if
is_prompt
:
assert
len
(
seq_ids
)
==
1
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_id
=
seq_ids
[
0
]
for
i
in
range
(
best_of
[
batch_idx
]):
for
i
in
range
(
model_input
.
best_of
[
batch_idx
]):
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
next_token_id
=
next_token_ids
[
batch_idx
][
i
]
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
SequenceOutput
(
seq_id
,
next_token_id
,
...
@@ -468,35 +569,6 @@ class TPUModelRunner:
...
@@ -468,35 +569,6 @@ class TPUModelRunner:
batch_idx
+=
1
batch_idx
+=
1
sampler_outputs
.
append
(
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
sampler_outputs
def
execute_model
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
if
seq_group_metadata_list
[
0
].
is_prompt
:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
# aware of this limitation, we need to handle batch size > 1
# internally by calling the model multiple times and concatenating
# the outputs.
# FIXME(woosuk): This is a temporary hack to not change the existing
# scheduler. We need to fix this in the future.
sampler_outputs
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampler_outputs
+=
self
.
_execute_model
([
seq_group_metadata
],
kv_caches
)
else
:
sampler_outputs
=
self
.
_execute_model
(
seq_group_metadata_list
,
kv_caches
)
return
[
SamplerOutput
(
sampler_outputs
)]
return
[
SamplerOutput
(
sampler_outputs
)]
...
@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module):
...
@@ -504,39 +576,37 @@ class ModelWrapper(nn.Module):
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
):
super
().
__init__
()
super
().
__init__
()
self
.
model
=
model
.
eval
()
self
.
model
=
model
def
forward
(
def
forward
(
self
,
self
,
token_ids
:
torch
.
Tensor
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
input_lens
:
torch
.
Tensor
,
multi_modal_kwargs
:
Optional
[
Mapping
[
str
,
BatchedTensors
]],
t
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model and samples the next token.
Args:
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
attn_metadata: The Pallas attention metadata.
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
input_lens: The actual input lengths of shape [batch_size].
multi_modal_kwargs: Keyword arguments from multi-modal data to
pass to the model.
t: The sampling temperature of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
"""
batch_size
,
seq_len
=
token_ids
.
shape
batch_size
,
seq_len
=
token_ids
.
shape
# Calculate the positions to sample from.
# Calculate the positions to sample from.
base
_indicies
=
torch
.
arange
(
start
_indicies
=
torch
.
arange
(
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
batch_size
,
dtype
=
torch
.
int32
,
device
=
input_lens
.
device
)
*
seq_len
logits_indices
=
base
_indicies
+
input_lens
-
1
logits_indices
=
start
_indicies
+
input_lens
-
1
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
# sampler and sampling metadata.
...
@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module):
...
@@ -573,7 +643,6 @@ class ModelWrapper(nn.Module):
position_ids
,
position_ids
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
**
(
multi_modal_kwargs
or
{}),
)
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
hidden_states
=
hidden_states
.
flatten
(
0
,
1
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
...
@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int:
...
@@ -598,11 +667,10 @@ def _get_padded_prefill_len(x: int) -> int:
def
_get_padded_batch_size
(
batch_size
:
int
)
->
int
:
def
_get_padded_batch_size
(
batch_size
:
int
)
->
int
:
if
batch_size
<=
2
:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
return
batch_size
# To meet this requirement in the simplest way, we set the minimal batch
elif
batch_size
<=
4
:
# size to 8.
return
4
if
batch_size
<=
8
:
elif
batch_size
<=
8
:
return
8
return
8
else
:
else
:
return
((
batch_size
+
15
)
//
16
)
*
16
return
((
batch_size
+
15
)
//
16
)
*
16
...
...
vllm/worker/tpu_worker.py
View file @
500b93c8
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -13,15 +13,16 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.tpu_model_runner
import
TPUModelRunner
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoraNotSupportedWorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoraNotSupportedWorkerBase
):
class
TPUWorker
(
LoraNotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -57,14 +58,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
cache_config
.
cache_dtype
]
self
.
model_runner
=
TPUModelRunner
(
model_config
,
self
.
model_runner
:
TPUModelRunner
=
TPUModelRunner
(
parallel_config
,
model_config
,
scheduler_config
,
parallel_config
,
device_config
,
scheduler_config
,
cache_config
,
device_config
,
load_config
,
cache_config
,
multimodal_config
,
load_config
,
is_driver_worker
=
is_driver_worker
)
multimodal_config
,
is_driver_worker
=
is_driver_worker
)
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
...
@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -98,8 +100,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
# Use persistent cache to avoid XLA recompilation.
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): This does not completely eliminate the recompilation
# NOTE(woosuk): This does not completely eliminate the recompilation
# overhead because dynamo does not cache the compiled results.
# overhead because dynamo does not cache the compiled results.
xr
.
initialize_cache
(
os
.
path
.
expanduser
(
envs
.
VLLM_XLA_CACHE_PATH
),
xr
.
initialize_cache
(
envs
.
VLLM_XLA_CACHE_PATH
,
readonly
=
False
)
readonly
=
False
)
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
...
@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
...
@@ -197,69 +198,70 @@ class TPUWorker(LoraNotSupportedWorkerBase):
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
dtype_size
=
get_dtype_size
(
self
.
cache_dtype
)
return
dtype_size
*
total
return
dtype_size
*
total
def
execute_model
(
@property
def do_metadata_broadcast(self) -> bool:
    """Report whether metadata broadcasting is enabled; always False here.

    NOTE(review): presumably read by LocalOrDistributedWorkerBase to decide
    whether model inputs are broadcast to other workers -- confirm there.
    """
    # TODO(woosuk): Support TP.
    return False
@
property
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
# parallelism.
return
[
self
.
tpu_cache
]
def
prepare_worker_input
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
execute_model_req
:
ExecuteModelRequest
,
)
->
List
[
SamplerOutput
]:
)
->
WorkerInput
:
if
not
self
.
is_driver_worker
:
virtual_engine
=
execute_model_req
.
virtual_engine
self
.
_execute_model_non_driver
()
num_seq_groups
=
len
(
execute_model_req
.
seq_group_metadata_list
)
return
[]
blocks_to_swap_in
=
_make_src_to_dst
(
assert
execute_model_req
is
not
None
execute_model_req
.
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
# Issue cache operations.
blocks_to_swap_out
=
_make_src_to_dst
(
self
.
cache_swap
(
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
execute_model_req
.
blocks_to_swap_in
,
blocks_to_copy
=
_make_src_to_dst
(
execute_model_req
.
blocks_to_copy
,
execute_model_req
.
blocks_to_swap_out
,
self
.
device
,
self
.
device
)
execute_model_req
.
blocks_to_copy
,
return
WorkerInput
(
num_seq_groups
=
num_seq_groups
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
virtual_engine
=
virtual_engine
,
)
)
# Run the model.
seq_group_metadata_list
=
execute_model_req
.
seq_group_metadata_list
def
execute_worker
(
self
,
worker_input
:
WorkerInput
)
->
None
:
assert
len
(
seq_group_metadata_list
)
>
0
virtual_engine
=
worker_input
.
virtual_engine
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
assert
virtual_engine
==
0
self
.
tpu_cache
)
return
output
def
cache_swap
(
self
,
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]],
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
attn_backend
=
self
.
model_runner
.
attn_backend
attn_backend
=
self
.
model_runner
.
attn_backend
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
if
blocks_to_swap_in
:
# Issue cache operations.
# Swap from CPU to TPU.
if
worker_input
.
blocks_to_swap_in
is
not
None
:
src_indices
,
dst_indices
=
_make_src_to_dst
(
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_in
blocks_to_swap_in
,
"cpu"
,
self
.
device
)
if
src_indices
.
numel
()
>
0
:
for
i
in
range
(
num_layers
):
# Swap from CPU to TPU.
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
for
i
in
range
(
num_layers
):
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
k
=
cpu_k_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
v
=
cpu_v_cache
[:,
src_indices
].
to
(
self
.
device
)
_insert_kv
(
k
,
v
,
dst_indices
,
tpu_k_cache
,
tpu_v_cache
)
if
blocks_to_swap_out
:
# Swap from TPU to CPU.
if
worker_input
.
blocks_to_swap_out
is
not
None
:
src_indices
,
dst_indices
=
_make_src_to_dst
(
src_indices
,
dst_indices
=
worker_input
.
blocks_to_swap_out
blocks_to_swap_out
,
self
.
device
,
"cpu"
)
if
src_indices
.
numel
()
>
0
:
for
i
in
range
(
num_layers
):
# Swap from TPU to CPU.
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
for
i
in
range
(
num_layers
):
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
tpu_k_cache
,
tpu_v_cache
=
self
.
tpu_cache
[
i
]
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
].
cpu
()
cpu_k_cache
,
cpu_v_cache
=
self
.
cpu_cache
[
i
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
].
cpu
()
cpu_k_cache
[:,
dst_indices
]
=
tpu_k_cache
[:,
src_indices
]
cpu_v_cache
[:,
dst_indices
]
=
tpu_v_cache
[:,
src_indices
]
if
blocks_to_copy
:
src_to_dst
=
_make_src_to_dst
(
blocks_to_copy
,
self
.
device
,
if
worker_input
.
blocks_to_copy
is
not
None
:
self
.
device
)
src_indices
,
dst_indices
=
worker_input
.
blocks_to_copy
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
src_to_dst
)
if
src_indices
.
numel
()
>
0
:
attn_backend
.
copy_blocks
(
self
.
tpu_cache
,
def
start_worker_execution_loop
(
self
)
->
None
:
(
src_indices
,
dst_indices
))
while
self
.
_execute_model_non_driver
():
pass
def
_execute_model_non_driver
(
self
)
->
bool
:
self
.
model_runner
.
execute_model
(
None
,
self
.
tpu_cache
)
return
True
def
_make_src_to_dst
(
def
_make_src_to_dst
(
...
...
vllm/worker/worker.py
View file @
500b93c8
...
@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
# initialize_cache.
# initialize_cache.
self
.
cache_engine
:
List
[
CacheEngine
]
self
.
cache_engine
:
List
[
CacheEngine
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
t
ensor
]]]
=
None
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
T
ensor
]]]
=
None
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
...
Prev
1
…
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment