SIYIXNI / vllm · Commits

Commit 51679bbd
authored Feb 01, 2024 by zhuwenwen

    resolve merge confilcts

Parents: 4095d0db, 1af090b5

Changes: 170 — showing 10 changed files with 982 additions and 170 deletions (+982 −170)
vllm/prefix.py                                    +87   −0
vllm/sampling_params.py                           +2    −2
vllm/sequence.py                                  +29   −1
vllm/test_utils.py                                +38   −0
vllm/transformers_utils/tokenizer.py              +80   −0
vllm/utils.py                                     +213  −7
vllm/worker/cache_engine.py                       +12   −3
vllm/worker/model_runner.py                       +282  −131
vllm/worker/spec_decode/multi_step_worker.py      +178  −0
vllm/worker/worker.py                             +61   −26
vllm/prefix.py (new file, 0 → 100644)

from typing import Dict, List, Sequence, Tuple, Optional

from vllm.block import BlockTable


class Prefix:
    """Data and states associated with a prefix of prompt tokens for multiple
    sequence groups.

    NOTE: This feature is experimental and may be replaced with automatic
    prefix caching in the future.

    Args:
        token_ids: The token ids of the prefix.
        block_size: The block size of the executed model.
    """

    def __init__(
        self,
        token_ids: Sequence[int],
        block_size: int,
    ) -> None:
        self.token_ids = tuple(token_ids)
        self.block_size = block_size
        self.length = len(token_ids)
        self.hash = hash(token_ids)
        assert self.length % block_size == 0
        self.block_table: Optional[BlockTable] = None
        self.computed = False

    @property
    def allocated(self) -> bool:
        return self.block_table is not None

    def get_num_blocks(self) -> int:
        return self.length // self.block_size

    def get_block_numbers(self) -> List[int]:
        return [block.block_number for block in self.block_table]

    def get_length(self) -> int:
        return self.length

    def __hash__(self) -> int:
        return self.hash

    def set_block_table(self, block_table: BlockTable) -> None:
        self.block_table = block_table.copy()


class PrefixPool:
    """Manages all the prompt prefixes.

    NOTE: This feature is experimental and may be replaced with automatic
    prefix caching in the future.

    Args:
        block_size: The block size of the executed model.

    Attributes:
        prefixes: A list of all the prefixes.
        block_size: The block size of the executed model.
    """

    def __init__(
        self,
        block_size: int,
    ) -> None:
        # TODO(zhuohan): Add a capacity limit to the prefix pool.
        self.prefixes: Dict[int, Prefix] = {}
        self.block_size = block_size

    def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
        new_length = len(token_ids) // self.block_size * self.block_size
        return tuple(token_ids[:new_length])

    def add_or_get_prefix(self, token_ids: Sequence[int],
                          lora_int_id: int) -> Optional[Prefix]:
        token_ids = self._truncate_token_ids(token_ids)
        if len(token_ids) == 0:
            # Prefix is empty.
            return None
        prefix = Prefix(token_ids, self.block_size)
        prefix_hash = hash((prefix, lora_int_id))
        if prefix_hash not in self.prefixes:
            self.prefixes[prefix_hash] = prefix
        return self.prefixes[prefix_hash]
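
A minimal usage sketch of the new PrefixPool, assuming the vllm.prefix module added above is importable; the token ids and block size are made-up illustration values, not taken from the commit.

from vllm.prefix import PrefixPool

pool = PrefixPool(block_size=16)

# 32 shared prompt tokens: a multiple of block_size, so nothing is truncated.
shared_prompt = list(range(32))
prefix = pool.add_or_get_prefix(shared_prompt, lora_int_id=0)
assert prefix is not None and prefix.get_num_blocks() == 2

# The same (token ids, LoRA id) pair maps back to the same pooled Prefix object.
assert pool.add_or_get_prefix(shared_prompt, lora_int_id=0) is prefix

# Fewer tokens than one block are truncated to an empty prefix, which returns None.
assert pool.add_or_get_prefix(list(range(10)), lora_int_id=0) is None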
vllm/sampling_params.py

@@ -108,7 +108,7 @@ class SamplingParams:
         stop_token_ids: Optional[List[int]] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,

@@ -183,7 +183,7 @@ class SamplingParams:
         if not 0.0 <= self.min_p <= 1.0:
             raise ValueError("min_p must be in [0, 1], got "
                              f"{self.min_p}.")
-        if self.max_tokens < 1:
+        if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:
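
The two-line change above makes max_tokens optional, so a request can be left uncapped. A rough sketch of the new validation behaviour, assuming a vLLM build with this patch installed and that no other default argument interferes:

from vllm.sampling_params import SamplingParams

# max_tokens=None is now accepted and skips the ">= 1" check entirely.
params = SamplingParams(max_tokens=None)

# An explicit non-positive value is still rejected.
try:
    SamplingParams(max_tokens=0)
except ValueError as e:
    print(e)  # max_tokens must be at least 1, got 0.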
vllm/sequence.py

@@ -4,7 +4,9 @@ import enum
 from typing import Dict, List, Optional, Union

 from vllm.block import LogicalTokenBlock
+from vllm.prefix import Prefix
 from vllm.sampling_params import SamplingParams
+from vllm.lora.request import LoRARequest

 PromptLogprobs = List[Optional[Dict[int, float]]]
 SampleLogprobs = List[Dict[int, float]]

@@ -105,6 +107,7 @@ class Sequence:
         prompt_token_ids: The token IDs of the prompt.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
+        lora_request: LoRA request.
     """

     def __init__(

@@ -113,10 +116,12 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.lora_request = lora_request

         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []

@@ -133,6 +138,10 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None

+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def _append_logical_block(self) -> None:
         block = LogicalTokenBlock(
             block_number=len(self.logical_token_blocks),

@@ -228,6 +237,8 @@ class SequenceGroup:
         seqs: The list of sequences.
         sampling_params: The sampling parameters used to generate the outputs.
         arrival_time: The arrival time of the request.
+        lora_request: LoRA request.
+        prefix: The prefix of the prompt of the sequence group.
     """

     def __init__(

@@ -236,11 +247,15 @@ class SequenceGroup:
         seqs: List[Sequence],
         sampling_params: SamplingParams,
         arrival_time: float,
+        lora_request: Optional[LoRARequest] = None,
+        prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.lora_request = lora_request
+        self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None

     @property

@@ -255,6 +270,10 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids

+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""

@@ -327,7 +346,6 @@ class SequenceGroup:
 class SequenceGroupMetadata:
     """Metadata for a sequence group. Used to create `InputMetadata`.

     Args:
         request_id: The ID of the request.
         is_prompt: Whether the request is at prompt stage.

@@ -335,6 +353,8 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        lora_request: LoRA request.
+        prefix: The prefix of the prompt of the sequence group.
     """

     def __init__(

@@ -344,12 +364,20 @@ class SequenceGroupMetadata:
         seq_data: Dict[int, SequenceData],
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
+        lora_request: Optional[LoRARequest] = None,
+        prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
         self.seq_data = seq_data
         self.sampling_params = sampling_params
         self.block_tables = block_tables
+        self.lora_request = lora_request
+        self.prefix = prefix
+
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0


 class SequenceOutput:
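
A small sketch of how the widened constructors fit together. This assumes a vLLM build at this commit is importable; the ids and token values are illustrative only, and the leading seq_id / request_id parameters (outside the hunks shown) follow the existing vLLM signatures.

import time

from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup

# Illustrative values; token ids would normally come from the tokenizer.
seq = Sequence(seq_id=0,
               prompt="Hello",
               prompt_token_ids=[101, 102, 103],
               block_size=16,
               lora_request=None)
group = SequenceGroup(request_id="req-0",
                      seqs=[seq],
                      sampling_params=SamplingParams(),
                      arrival_time=time.time(),
                      lora_request=None,
                      prefix=None)

# Without a LoRA request, both fall back to the base adapter id 0.
assert seq.lora_int_id == 0 and group.lora_int_id == 0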
vllm/test_utils.py (new file, 0 → 100644)

import ray

from vllm.config import ParallelConfig
from vllm.utils import get_open_port
from vllm.worker.worker import init_distributed_environment


def init_test_distributed_environment(
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(parallel_config, rank,
                                 distributed_init_method)


def multi_process_tensor_parallel(
    tensor_parallel_size: int,
    test_target,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tensor_parallel_size):
        refs.append(
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port))
    ray.get(refs)

    ray.shutdown()
vllm/transformers_utils/tokenizer.py

@@ -4,6 +4,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.utils import make_async, LRUCache
 from vllm.transformers_utils.tokenizers import *

 logger = init_logger(__name__)

@@ -65,6 +67,84 @@ def get_tokenizer(
     return tokenizer


+def get_lora_tokenizer(lora_request: LoRARequest, *args,
+                       **kwargs) -> Optional[PreTrainedTokenizer]:
+    if lora_request is None:
+        return None
+    try:
+        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
+                                  **kwargs)
+    except OSError as e:
+        # No tokenizer was found in the LoRA folder,
+        # use base model tokenizer
+        logger.warning(
+            f"No tokenizer found in {lora_request.lora_local_path}, "
+            "using base model tokenizer instead. "
+            f"(Exception: {str(e)})")
+        tokenizer = None
+    return tokenizer
+
+
+get_lora_tokenizer_async = make_async(get_lora_tokenizer)
+
+
+class TokenizerGroup:
+    """A group of tokenizers that can be used for LoRA adapters."""
+
+    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
+                 max_input_length: Optional[int], **tokenizer_config):
+        self.tokenizer_id = tokenizer_id
+        self.tokenizer_config = tokenizer_config
+        self.enable_lora = enable_lora
+        self.max_input_length = max_input_length
+        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
+        if enable_lora:
+            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
+        else:
+            self.lora_tokenizers = None
+
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = self.get_lora_tokenizer(lora_request)
+        return tokenizer.encode(prompt)
+
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
+        tokenizer = await self.get_lora_tokenizer_async(lora_request)
+        return tokenizer.encode(prompt)
+
+    def get_lora_tokenizer(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (get_lora_tokenizer(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+    async def get_lora_tokenizer_async(
+            self,
+            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+        if not lora_request or not self.enable_lora:
+            return self.tokenizer
+        if lora_request.lora_int_id not in self.lora_tokenizers:
+            tokenizer = (await get_lora_tokenizer_async(
+                lora_request, **self.tokenizer_config) or self.tokenizer)
+            self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
+            return tokenizer
+        else:
+            return self.lora_tokenizers.get(lora_request.lora_int_id)
+
+
 def _convert_tokens_to_string_with_added_encoders(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     output_tokens: List[str],
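
A usage sketch of the new TokenizerGroup. The model id "facebook/opt-125m" is only a placeholder (any Hugging Face tokenizer id works), and constructing the group downloads or loads that tokenizer via get_tokenizer().

from vllm.transformers_utils.tokenizer import TokenizerGroup

group = TokenizerGroup(tokenizer_id="facebook/opt-125m",
                       enable_lora=True,
                       max_num_seqs=4,
                       max_input_length=None)

# With no LoRA request, the base tokenizer is used directly.
print(group.encode("Hello world", request_id="req-0", lora_request=None))

# With a LoRA request, the per-adapter tokenizer is looked up in (or inserted
# into) an LRUCache keyed by lora_int_id, falling back to the base tokenizer
# when the adapter directory contains no tokenizer files.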
vllm/utils.py

 import enum
 import os
 import socket
 import subprocess
 import uuid
 from platform import uname
-from typing import List
+from typing import List, Tuple, Union
+from packaging.version import parse, Version

 import psutil
 import torch
 import asyncio
 from functools import partial
 from typing import (
     Awaitable,
     Callable,
     TypeVar,
 )
 from collections import OrderedDict
 from typing import Any, Hashable, Optional

-from vllm._C import cuda_utils
+from vllm.logger import init_logger

 T = TypeVar("T")
+logger = init_logger(__name__)

+STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8_e5m2": torch.uint8,
+}

 class Device(enum.Enum):

@@ -30,16 +51,83 @@ class Counter:
         self.counter = 0


+class LRUCache:
+
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def __contains__(self, key: Hashable) -> bool:
+        return key in self.cache
+
+    def __len__(self) -> int:
+        return len(self.cache)
+
+    def __getitem__(self, key: Hashable) -> Any:
+        return self.get(key)
+
+    def __setitem__(self, key: Hashable, value: Any) -> None:
+        self.put(key, value)
+
+    def __delitem__(self, key: Hashable) -> None:
+        self.pop(key)
+
+    def touch(self, key: Hashable) -> None:
+        self.cache.move_to_end(key)
+
+    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+        if key in self.cache:
+            value = self.cache[key]
+            self.cache.move_to_end(key)
+        else:
+            value = default_value
+        return value
+
+    def put(self, key: Hashable, value: Any) -> None:
+        self.cache[key] = value
+        self.cache.move_to_end(key)
+        self._remove_old_if_needed()
+
+    def _on_remove(self, key: Hashable, value: Any):
+        pass
+
+    def remove_oldest(self):
+        if not self.cache:
+            return
+        key, value = self.cache.popitem(last=False)
+        self._on_remove(key, value)
+
+    def _remove_old_if_needed(self) -> None:
+        while len(self.cache) > self.capacity:
+            self.remove_oldest()
+
+    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+        run_on_remove = key in self.cache
+        value = self.cache.pop(key, default_value)
+        if run_on_remove:
+            self._on_remove(key, value)
+        return value
+
+    def clear(self):
+        while len(self.cache) > 0:
+            self.remove_oldest()
+        self.cache.clear()
+
+
 def is_hip() -> bool:
     return torch.version.hip is not None


 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
-    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
-    max_shared_mem = cuda_utils.get_device_attribute(
-        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
+    # NOTE: This import statement should be executed lazily since
+    # the Neuron-X backend does not have the `cuda_utils` module.
+    from vllm._C import cuda_utils
+
+    max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)
+    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py will fail
+    assert max_shared_mem > 0, "max_shared_mem can not be zero"
     return int(max_shared_mem)

@@ -57,8 +145,30 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


+def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
+    """Take a blocking function, and run it on in an executor thread.
+
+    This function prevents the blocking function from blocking the
+    asyncio event loop.
+    The code in this function needs to be thread safe.
+    """
+
+    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
+        loop = asyncio.get_event_loop()
+        p_func = partial(func, *args, **kwargs)
+        return loop.run_in_executor(executor=None, func=p_func)
+
+    return _async_wrapper
+
+
 def get_ip() -> str:
-    return socket.gethostbyname(socket.gethostname())
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+    return s.getsockname()[0]
+
+
+def get_distributed_init_method(ip: str, port: int) -> str:
+    return f"tcp://{ip}:{port}"


 def get_open_port() -> int:

@@ -69,3 +179,99 @@ def get_open_port() -> int:
 def set_cuda_visible_devices(device_ids: List[int]) -> None:
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+
+
+def get_nvcc_cuda_version() -> Version:
+    cuda_home = os.environ.get('CUDA_HOME')
+    if not cuda_home:
+        cuda_home = '/usr/local/cuda'
+        logger.info(
+            f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
+        )
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+
+def _generate_random_fp8_e5m2(
+    tensor: torch.tensor,
+    low: float,
+    high: float,
+) -> None:
+    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
+    # it may occur Inf or NaN if we directly use torch.randint
+    # to generate random data for fp8 data.
+    # For example, s.11111.00 in fp8e5m2 format repesents Inf.
+    #     | E4M3        | E5M2
+    # ----|-------------|-------------------
+    # Inf | N/A         | s.11111.00
+    # NaN | s.1111.111  | s.11111.{01,10,11}
+    from vllm._C import cache_ops
+    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
+    tensor_tmp.uniform_(low, high)
+    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
+    del tensor_tmp
+
+
+def create_kv_caches_with_random(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    cache_dtype: Optional[Union[str, torch.dtype]],
+    model_dtype: Optional[Union[str, torch.dtype]] = None,
+    seed: Optional[int] = 0,
+    device: Optional[str] = "cuda",
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    if isinstance(cache_dtype, str):
+        if cache_dtype == "auto":
+            if isinstance(model_dtype, str):
+                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(model_dtype, torch.dtype):
+                torch_dtype = model_dtype
+            else:
+                raise ValueError(f"Invalid model dtype: {model_dtype}")
+        elif cache_dtype in ["half", "bfloat16", "float"]:
+            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        elif cache_dtype == "fp8_e5m2":
+            torch_dtype = torch.uint8
+        else:
+            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+    elif isinstance(cache_dtype, torch.dtype):
+        torch_dtype = cache_dtype
+    else:
+        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=torch_dtype,
+                                device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=torch_dtype,
+                                  device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
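
A small sketch of the new LRUCache, including the _on_remove hook that eviction-aware subclasses (such as the LoRA tokenizer/adapter caches) can override. EvictionLoggingCache is a made-up subclass name for illustration; the import assumes a vLLM build with this patch.

from vllm.utils import LRUCache


class EvictionLoggingCache(LRUCache):
    """Hypothetical subclass: the _on_remove hook fires on eviction and pop."""

    def _on_remove(self, key, value):
        print(f"evicted {key!r} -> {value!r}")


cache = EvictionLoggingCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")       # touches "a", so "b" becomes the least recently used entry
cache.put("c", 3)    # prints: evicted 'b' -> 2
assert "b" not in cache and cache["a"] == 1 and len(cache) == 2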
vllm/worker/cache_engine.py

@@ -6,7 +6,7 @@ import torch
 from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE

 logger = init_logger(__name__)

@@ -34,12 +34,16 @@ class CacheEngine:
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
-        self.dtype = model_config.dtype

         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

+        if cache_config.cache_dtype == "auto":
+            self.dtype = model_config.dtype
+        else:
+            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()

@@ -142,6 +146,7 @@ class CacheEngine:
     @staticmethod
     def get_cache_block_size(
         block_size: int,
+        cache_dtype: str,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:

@@ -152,7 +157,11 @@ class CacheEngine:
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
-        dtype_size = _get_dtype_size(model_config.dtype)
+        if cache_dtype == "auto":
+            dtype = model_config.dtype
+        else:
+            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype_size = _get_dtype_size(dtype)
         return dtype_size * total
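
The updated get_cache_block_size sizes a KV-cache block from the chosen cache dtype rather than always using the model dtype. A standalone sketch of the same arithmetic, using made-up 7B-like dimensions (32 layers, 32 KV heads, head size 128, block size 16), which are illustration values only:

import torch

STR_DTYPE_TO_TORCH_DTYPE = {"half": torch.half, "fp8_e5m2": torch.uint8}


def cache_block_size_bytes(block_size: int, num_heads: int, head_size: int,
                           num_layers: int, dtype: torch.dtype) -> int:
    # Mirrors CacheEngine.get_cache_block_size: one key block plus one value
    # block per layer, times the element size of the cache dtype.
    key_cache_block = block_size * num_heads * head_size
    value_cache_block = key_cache_block
    total = num_layers * (key_cache_block + value_cache_block)
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    return dtype_size * total


for name in ("half", "fp8_e5m2"):
    size = cache_block_size_bytes(block_size=16, num_heads=32, head_size=128,
                                  num_layers=32,
                                  dtype=STR_DTYPE_TO_TORCH_DTYPE[name])
    print(name, size)  # fp8_e5m2 blocks are half the size of fp16 blocks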
vllm/worker/model_runner.py

 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Set, Union

 import numpy as np
 import torch
 import torch.nn as nn

-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast, broadcast_object_list)
+    broadcast_tensor_dict)
+from vllm.model_executor.parallel_utils import custom_all_reduce
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
 from vllm.utils import in_wsl

 logger = init_logger(__name__)

 KVCache = Tuple[torch.Tensor, torch.Tensor]
 _PAD_SLOT_ID = -1
+LORA_WARMUP_RANK = 8
 # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

@@ -30,19 +35,24 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.lora_config = lora_config
         self.is_driver_worker = is_driver_worker

         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
+        self.device = torch.device(torch.cuda.current_device())
         self.model = None
         self.block_size = None  # Set after initial profiling.
+        self.lora_manager = None

         self.graph_runners: Dict[int, CUDAGraphRunner] = {}
         self.graph_memory_pool = None  # Set during graph capture.

@@ -59,9 +69,20 @@ class ModelRunner:
         self.graph_block_tables = None  # Set after initial profiling.
         # cache in_wsl result
         self.in_wsl = in_wsl()
+        self.kv_cache_dtype = kv_cache_dtype

     def load_model(self) -> None:
-        self.model = get_model(self.model_config)
+        self.model = get_model(self.model_config, self.lora_config)
+        vocab_size = self.model.config.vocab_size
+
+        if self.lora_config:
+            self.lora_manager = LRUCacheWorkerLoRAManager(
+                self.scheduler_config.max_num_seqs,
+                self.scheduler_config.max_num_batched_tokens +
+                self.scheduler_config.max_paddings, vocab_size,
+                self.lora_config, self.device)
+            self.model = self.lora_manager.create_lora_manager(self.model)

     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size

@@ -74,13 +95,20 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
+               List[int], List[int], List[int], Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()

         prompt_lens: List[int] = []
+        context_lens: List[int] = []
+        subquery_lens: List[int] = []
+        prefix_block_tables: List[List[int]] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())

@@ -91,11 +119,34 @@ class ModelRunner:
             prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
+            prefix_len = 0
+            prefix = seq_group_metadata.prefix
+            if prefix is not None and prefix.computed:
+                prefix_len = prefix.get_length()
+                prompt_tokens = prompt_tokens[prefix_len:]
+                prefix_block_tables.append(prefix.get_block_numbers())
+            else:
+                prefix_block_tables.append([])
+            # actual prompt lens
+            context_lens.append(prefix_len)
+            subquery_lens.append(prompt_len - prefix_len)

             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
+            input_positions.append(
+                list(range(prefix_len, prefix_len + len(prompt_tokens))))
+
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)
+
+            lora_index_mapping.append([lora_id] * prompt_len)
+            lora_prompt_mapping.extend(
+                [lora_id] *
+                (prompt_len
+                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))

             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized

@@ -113,8 +164,11 @@ class ModelRunner:
             # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
             start_idx = 0
             if self.sliding_window is not None:
+                assert prefix_len == 0, (
+                    "Prefix caching is currently not supported with "
+                    "sliding window attention")
                 start_idx = max(0, prompt_len - self.sliding_window)
-            for i in range(prompt_len):
+            for i in range(prefix_len, prompt_len):
                 if i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue

@@ -124,7 +178,7 @@ class ModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)

-        max_prompt_len = max(prompt_lens)
+        max_prompt_len = max(subquery_lens)
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,

@@ -137,32 +191,70 @@ class ModelRunner:
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
+        lora_index_mapping = [
+            _pad_to_max(mapping, max_prompt_len, pad=0)
+            for mapping in lora_index_mapping
+        ]
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device='cuda')
+        # Prepare prefix block tables
+        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
+        block_tables = _make_tensor_with_pad(
+            prefix_block_tables,
+            max_len=max_prompt_block_table_len,
+            pad=0,
+            dtype=torch.int,
+        )
+        start_loc_tensor = torch.arange(0,
+                                        len(prompt_lens) * max_prompt_len,
+                                        max_prompt_len,
+                                        dtype=torch.long,
+                                        device='cuda')
+        prompt_lens_tensor = torch.tensor(prompt_lens,
+                                          dtype=torch.long,
+                                          device='cuda')

         input_metadata = InputMetadata(
             is_prompt=True,
             slot_mapping=slot_mapping,
+            prompt_lens=prompt_lens_tensor,
+            max_seq_len=max_prompt_len,
+            start_loc=start_loc_tensor,
             max_context_len=None,
-            context_lens=None,
-            block_tables=None,
+            context_lens=context_lens_tensor,
+            block_tables=block_tables,
             use_cuda_graph=False,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
-        return input_tokens, input_positions, input_metadata, prompt_lens
+        return (input_tokens, input_positions, input_metadata, prompt_lens,
+                subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                lora_requests)

     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
+               List[int], Set[LoRARequest]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
         block_tables: List[List[int]] = []
+        lora_index_mapping: List[int] = []
+        lora_prompt_mapping: List[int] = []
+        lora_requests: Set[LoRARequest] = set()

         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt

             seq_ids = list(seq_group_metadata.seq_data.keys())
+            lora_id = seq_group_metadata.lora_int_id
+
+            if lora_id > 0:
+                lora_requests.add(seq_group_metadata.lora_request)

             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()

@@ -181,6 +273,8 @@ class ModelRunner:
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
+                lora_index_mapping.append([lora_id])
+                lora_prompt_mapping.append(lora_id)

                 if self.sliding_window is not None:
                     sliding_window_blocks = (self.sliding_window //

@@ -235,28 +329,39 @@ class ModelRunner:
                 input_block_tables[i, :len(block_table)] = block_table
             block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
+            max_block_table_len = max(
+                len(block_table) for block_table in block_tables)
             block_tables = _make_tensor_with_pad(
                 block_tables,
-                max_len=max_context_len,
+                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
                 device="cuda",
             )

+        lora_index_mapping = [
+            _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
+        ]
+
         input_metadata = InputMetadata(
             is_prompt=False,
             slot_mapping=slot_mapping,
+            prompt_lens=None,
+            max_seq_len=None,
+            start_loc=None,
             max_context_len=max_context_len,
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
-        return input_tokens, input_positions, input_metadata
+        return (input_tokens, input_positions, input_metadata,
+                lora_index_mapping, lora_prompt_mapping, lora_requests)

     def _prepare_sample(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         prompt_lens: List[int],
+        subquery_lens: Optional[List[int]],
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []

@@ -264,7 +369,7 @@ class ModelRunner:
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0

-        max_prompt_len = max(prompt_lens) if prompt_lens else 1
+        max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params

@@ -272,10 +377,11 @@ class ModelRunner:
             if seq_group_metadata.is_prompt:
                 assert len(seq_ids) == 1
                 prompt_len = prompt_lens[i]
+                assert subquery_lens is not None
+                subquery_len = subquery_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     # NOTE: prompt token positions do not need sample, skip
-                    categorized_sample_indices_start_idx += prompt_len - 1
+                    categorized_sample_indices_start_idx += subquery_len - 1

                 categorized_sample_indices[
                     sampling_params.sampling_type].append(

@@ -285,10 +391,10 @@ class ModelRunner:
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
+                              selected_token_start_idx + subquery_len - 1))
                 selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_prompt_len
+                                              subquery_len - 1)
+                selected_token_start_idx += max_subquery_len
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(

@@ -326,115 +432,86 @@ class ModelRunner:
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
+               Set[int], LoRAMapping]:
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata,
-                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+                (input_tokens, input_positions, input_metadata, prompt_lens,
+                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_prompt(seq_group_metadata_list)
             else:
-                (input_tokens, input_positions,
-                 input_metadata) = self._prepare_decode(seq_group_metadata_list)
+                (input_tokens, input_positions, input_metadata,
+                 lora_index_mapping, lora_prompt_mapping,
+                 lora_requests) = self._prepare_decode(seq_group_metadata_list)
                 prompt_lens = []
+                subquery_lens = None
             sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
-
-            def get_size_or_none(x: Optional[torch.Tensor]):
-                return x.size() if x is not None else None
-
-            # Broadcast the input data. For input tensors, we first broadcast
-            # its shape and then broadcast the tensor to avoid high
-            # serialization cost.
-            py_data = {
-                "input_tokens_size": input_tokens.size(),
-                "input_positions_size": input_positions.size(),
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens_size": get_size_or_none(input_metadata.context_lens),
-                "block_tables_size": get_size_or_none(input_metadata.block_tables),
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "selected_token_indices_size":
-                sampling_metadata.selected_token_indices.size(),
-            }
-            broadcast_object_list([py_data], src=0)
-            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
-            broadcast(input_tokens, src=0)
-            broadcast(input_positions, src=0)
-            if input_metadata.slot_mapping is not None:
-                broadcast(input_metadata.slot_mapping, src=0)
-            if input_metadata.context_lens is not None:
-                broadcast(input_metadata.context_lens, src=0)
-            if input_metadata.block_tables is not None:
-                broadcast(input_metadata.block_tables, src=0)
-            broadcast(sampling_metadata.selected_token_indices, src=0)
+                                                     prompt_lens,
+                                                     subquery_lens)
+
+            if self.lora_config:
+                flat_lora_index_mapping = [
+                    item for sublist in lora_index_mapping for item in sublist
+                ]
+                lora_mapping = LoRAMapping(
+                    flat_lora_index_mapping,
+                    lora_prompt_mapping,
+                )
+            else:
+                lora_mapping = None
+
+            # Broadcast the metadata.
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping": input_metadata.slot_mapping,
+                "prompt_lens": input_metadata.prompt_lens,
+                "max_seq_len": input_metadata.max_seq_len,
+                "start_loc": input_metadata.start_loc,
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens": input_metadata.context_lens,
+                "block_tables": input_metadata.block_tables,
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "kv_cache_dtype": input_metadata.kv_cache_dtype,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
+                "lora_requests": lora_requests,
+                "lora_mapping": lora_mapping,
+            }
+            broadcast_tensor_dict(metadata_dict, src=0)
         else:
-            receving_list = [None]
-            broadcast_object_list(receving_list, src=0)
-            py_data = receving_list[0]
-            input_tokens = torch.empty(*py_data["input_tokens_size"],
-                                       dtype=torch.long,
-                                       device="cuda")
-            broadcast(input_tokens, src=0)
-            input_positions = torch.empty(*py_data["input_positions_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-            broadcast(input_positions, src=0)
-            if py_data["slot_mapping_size"] is not None:
-                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
-                                           dtype=torch.long,
-                                           device="cuda")
-                broadcast(slot_mapping, src=0)
-            else:
-                slot_mapping = None
-            if py_data["context_lens_size"] is not None:
-                context_lens = torch.empty(*py_data["context_lens_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(context_lens, src=0)
-            else:
-                context_lens = None
-            if py_data["block_tables_size"] is not None:
-                block_tables = torch.empty(*py_data["block_tables_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(block_tables, src=0)
-            else:
-                block_tables = None
-            selected_token_indices = torch.empty(
-                *py_data["selected_token_indices_size"],
-                dtype=torch.long,
-                device="cuda")
-            broadcast(selected_token_indices, src=0)
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict["input_tokens"]
+            input_positions = metadata_dict["input_positions"]
+            lora_mapping = metadata_dict["lora_mapping"]
+            lora_requests = metadata_dict["lora_requests"]
             input_metadata = InputMetadata(
-                is_prompt=py_data["is_prompt"],
-                slot_mapping=slot_mapping,
-                max_context_len=py_data["max_context_len"],
-                context_lens=context_lens,
-                block_tables=block_tables,
-                use_cuda_graph=py_data["use_cuda_graph"],
+                is_prompt=metadata_dict["is_prompt"],
+                slot_mapping=metadata_dict["slot_mapping"],
+                prompt_lens=metadata_dict["prompt_lens"],
+                max_seq_len=metadata_dict["max_seq_len"],
+                start_loc=metadata_dict["start_loc"],
+                max_context_len=metadata_dict["max_context_len"],
+                context_lens=metadata_dict["context_lens"],
+                block_tables=metadata_dict["block_tables"],
+                use_cuda_graph=metadata_dict["use_cuda_graph"],
+                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=selected_token_indices,
+                selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
                 perform_sampling=False,
             )

-        return input_tokens, input_positions, input_metadata, sampling_metadata
+        return (input_tokens, input_positions, input_metadata,
+                sampling_metadata, lora_requests, lora_mapping)

     @torch.inference_mode()
     def execute_model(

@@ -442,8 +519,12 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata = (
-            self.prepare_input_tensors(seq_group_metadata_list))
+        (input_tokens, input_positions, input_metadata, sampling_metadata,
+         lora_requests,
+         lora_mapping) = (self.prepare_input_tensors(seq_group_metadata_list))
+
+        if self.lora_config:
+            self.set_active_loras(lora_requests, lora_mapping)

         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]

@@ -472,6 +553,28 @@ class ModelRunner:
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs

+        # This represents the maximum number of different requests
+        # that will have unique loras, an therefore the max amount of memory
+        # consumption create dummy lora request copies from the lora request
+        # passed in, which contains a lora from the lora warmup path.
+        dummy_lora_requests = []
+        dummy_lora_requests_per_seq = []
+        if self.lora_config:
+            for idx in range(self.lora_config.max_loras):
+                lora_id = idx + 1
+                dummy_lora_request = LoRARequest(
+                    lora_name=f"warmup_{lora_id}",
+                    lora_int_id=lora_id,
+                    lora_local_path="/not/a/real/path",
+                )
+                self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                 rank=LORA_WARMUP_RANK)
+                dummy_lora_requests.append(dummy_lora_request)
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
+
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []

@@ -485,6 +588,8 @@ class ModelRunner:
                 seq_data={group_id: seq_data},
                 sampling_params=sampling_params,
                 block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
             )
             seqs.append(seq)

@@ -495,6 +600,32 @@ class ModelRunner:
         torch.cuda.synchronize()
         return

+    def remove_all_loras(self) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_all_loras()
+
+    def set_active_loras(self, lora_requests: List[LoRARequest],
+                         lora_mapping: LoRAMapping) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        self.lora_manager.set_active_loras(lora_requests, lora_mapping)
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.list_loras()
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         assert not self.model_config.enforce_eager

@@ -504,7 +635,9 @@ class ModelRunner:
                     "use '--enforce-eager' in the CLI.")
         logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                     "If you are running out of memory, consider decreasing "
-                    "`gpu_memory_utilization` or enforcing eager mode.")
+                    "`gpu_memory_utilization` or enforcing eager mode. "
+                    "You can also reduce the `max_num_seqs` as needed "
+                    "to decrease memory usage.")
         start_time = time.perf_counter()

         # Prepare dummy inputs. These will be reused for all batch sizes.

@@ -517,29 +650,47 @@ class ModelRunner:
         context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()

+        graph_batch_size = _get_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
+        batch_size_capture_list = [
+            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
+        ]
+
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
-            # Create dummy input_metadata.
-            input_metadata = InputMetadata(
-                is_prompt=False,
-                slot_mapping=slot_mapping[:batch_size],
-                max_context_len=self.max_context_len_to_capture,
-                context_lens=context_lens[:batch_size],
-                block_tables=block_tables[:batch_size],
-                use_cuda_graph=True,
-            )
-
-            graph_runner = CUDAGraphRunner(self.model)
-            graph_runner.capture(
-                input_tokens[:batch_size],
-                input_positions[:batch_size],
-                kv_caches,
-                input_metadata,
-                memory_pool=self.graph_memory_pool,
-            )
-            self.graph_memory_pool = graph_runner.graph.pool()
-            self.graph_runners[batch_size] = graph_runner
+        with custom_all_reduce.capture():
+            for batch_size in reversed(batch_size_capture_list):
+                # Create dummy input_metadata.
+                input_metadata = InputMetadata(
+                    is_prompt=False,
+                    slot_mapping=slot_mapping[:batch_size],
+                    prompt_lens=None,
+                    max_seq_len=None,
+                    start_loc=None,
+                    max_context_len=self.max_context_len_to_capture,
+                    context_lens=context_lens[:batch_size],
+                    block_tables=block_tables[:batch_size],
+                    use_cuda_graph=True,
+                    kv_cache_dtype=self.kv_cache_dtype,
+                )
+
+                if self.lora_config:
+                    lora_mapping = LoRAMapping(
+                        [0] * batch_size,
+                        [0] * batch_size,
+                    )
+                    self.set_active_loras(set(), lora_mapping)
+
+                graph_runner = CUDAGraphRunner(self.model)
+                graph_runner.capture(
+                    input_tokens[:batch_size],
+                    input_positions[:batch_size],
+                    kv_caches,
+                    input_metadata,
+                    memory_pool=self.graph_memory_pool,
+                )
+                self.graph_memory_pool = graph_runner.graph.pool()
+                self.graph_runners[batch_size] = graph_runner

         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
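
The graph-capture loop above now only captures batch sizes up to _get_graph_batch_size(max_num_seqs). A small standalone sketch of which sizes end up captured; the rounding inside _get_graph_batch_size itself is not shown in this hunk, so graph_batch_size=24 below is just an assumed, made-up value.

# Capture grid from model_runner.py: 1, 2, 4, then multiples of 8 up to 256.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


def batch_sizes_to_capture(graph_batch_size: int):
    # Same filter as capture_model(): keep only sizes <= graph_batch_size,
    # captured largest-first to reduce CUDA graph memory usage.
    return list(reversed([bs for bs in _BATCH_SIZES_TO_CAPTURE
                          if bs <= graph_batch_size]))


# With graph_batch_size = 24, only 6 graphs are built instead of the full
# grid of 35.
print(batch_sizes_to_capture(24))  # [24, 16, 8, 4, 2, 1]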
vllm/worker/spec_decode/multi_step_worker.py (new file, 0 → 100644)

from typing import List, Dict
import copy

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MultiStepWorker(Worker):
    """The MultiStepWorker is equivalent to a Worker except that it allows
    multiple forward passes in a single call, assuming the scheduler has
    allocated enough space to store the additional KV. This reduces overhead
    by invoking the scheduler less.

    The MultiStepWorker does not support cache swap operations, or beam search.
    Cache swap operations do not require large modifications. On the other hand,
    beam search requires memory allocations during sequence forks and thus
    requires more thought for MultiStepWorker support.
    """

    @torch.inference_mode()
    def execute_model_multi_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_steps: int,
    ) -> List[SamplerOutput]:
        """Run the model forward pass num_steps times. Returns the list of
        sampler output, one per model forward pass.
        """
        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
                                   blocks_to_swap_out, blocks_to_copy)

        # Shallow copy input data so modifications (such as appending tokens)
        # do not cause side-effects.
        copied_seq_group_metadata_list = self._shallow_copy_inputs(
            seq_group_metadata_list)

        # Assert enough KV space for num_steps tokens per sequence.
        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)

        # Run model num_steps times.
        model_outputs = []
        for _ in range(num_steps):
            model_output = super().execute_model(
                seq_group_metadata_list=copied_seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

            self._append_new_tokens(model_output,
                                    copied_seq_group_metadata_list)
            model_outputs.append(model_output)

        return model_outputs

    def _append_new_tokens(
            self, model_output: SamplerOutput,
            seq_group_metadata_list: SequenceGroupMetadata) -> None:
        """Given model output from a single run, append the tokens to the
        sequences. This is normally done outside of the worker, but it is
        required if the worker is to perform multiple forward passes.
        """
        for seq_group_metadata, sequence_group_outputs in zip(
                seq_group_metadata_list, model_output):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                # NOTE: Beam search is not supported, so we can assume that
                # parent_seq_id == seq_id.
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob)

    def _shallow_copy_inputs(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> List[SequenceGroupMetadata]:
        """Copy input data structures to remove side-effects when input data
        structures are shared with other modules.

        The multi-step worker must be able to append tokens to sequences after
        a forward pass. This necessitates modification of the data structures
        used by the worker. Since these data structures are shared with other
        parts of vLLM, like the scheduler, we must take care not to introduce
        unexpected side-effects.

        When Ray is used to orchestrate worker processes (such as when the
        tensor-parallel degree is >1), this is not a problem because the input
        datastructures will be serialized and created anew in the worker
        process.

        However, when Ray is not used to orchestrate the worker processes (such
        as when the tensor-parallel degree is 1), this is a problem. We avoid
        the problem by shallow-copying the input datastructures (specifically,
        the parts that will change in multiple steps).
        """

        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
        # append tokens and change is_prompt without external side-effects.
        new_seq_group_metadata_list = []

        for old_seq_group_metadata in seq_group_metadata_list:
            # We must shallow-copy seq_group_metadata as is_prompt could change.
            seq_group_metadata = copy.copy(old_seq_group_metadata)
            new_seq_group_metadata_list.append(seq_group_metadata)

            # We must shallow-copy seq_data as we will append token ids
            new_seq_data = {}
            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
                new_seq_data[seq_id] = copy.copy(old_seq_data)
                new_seq_data[seq_id].output_token_ids = (
                    old_seq_data.output_token_ids[:])

            seq_group_metadata.seq_data = new_seq_data

        return new_seq_group_metadata_list

    def _assert_enough_kv_space(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            num_steps: int) -> None:
        """Assert there are enough physical blocks per sequence to store the
        current KV plus additional KV from num_steps tokens.
        """
        assert self.model_runner.block_size is not None
        for seq_group_metadata in seq_group_metadata_list:
            # Only one seq_id is guaranteed because there is no beam search.
            seq_id = list(seq_group_metadata.seq_data.keys())[0]
            seq = seq_group_metadata.seq_data[seq_id]

            # After num_steps, the seq len will be the current seq len
            # plus one token per step.
            final_seq_len = seq.get_len() + num_steps

            # We will have final_seq_len - 1 KV because vLLM saves KV for a
            # token in the iteration after the token was generated.
            required_num_kv_slots = final_seq_len - 1

            # The allocated number of kv slots is the number of allocated blocks
            # times the number of slots of block.
            number_physical_blocks = len(
                seq_group_metadata.block_tables[seq_id])
            allocated_kv_slots = (number_physical_blocks *
                                  self.model_runner.block_size)

            if required_num_kv_slots > allocated_kv_slots:
                request_id = seq_group_metadata.request_id
                raise ValueError(
                    "The worker attempted to run "
                    f"{num_steps} times but found insufficient KV space for "
                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
                    f"{required_num_kv_slots=}).")

    def _raise_if_unsupported(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """MultiStepWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
            raise NotImplementedError(
                "MultiStepWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in seq_group_metadata_list):
            raise NotImplementedError(
                "MultiStepWorker does not support beam search.")
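
The _shallow_copy_inputs docstring explains why the per-sequence data must be copied before the worker appends speculative tokens. A standalone illustration of that pitfall using plain dataclasses; FakeSeqData and FakeGroupMetadata are stand-ins for the vLLM types, not part of the commit.

import copy
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeSeqData:            # stand-in for vllm.sequence.SequenceData
    output_token_ids: List[int] = field(default_factory=list)


@dataclass
class FakeGroupMetadata:      # stand-in for SequenceGroupMetadata
    seq_data: Dict[int, FakeSeqData]


scheduler_view = FakeGroupMetadata(seq_data={0: FakeSeqData()})

# A naive copy.copy shares seq_data and its token lists, so the scheduler
# would observe the tokens the multi-step worker appends.
leaky = copy.copy(scheduler_view)
leaky.seq_data[0].output_token_ids.append(42)
assert scheduler_view.seq_data[0].output_token_ids == [42]   # side-effect!

# The MultiStepWorker pattern: shallow-copy the metadata, rebuild seq_data,
# and slice the token list so appends stay local to the worker.
scheduler_view.seq_data[0].output_token_ids.clear()
isolated = copy.copy(scheduler_view)
isolated.seq_data = {
    sid: copy.copy(sd) for sid, sd in scheduler_view.seq_data.items()
}
for sid, sd in scheduler_view.seq_data.items():
    isolated.seq_data[sid].output_token_ids = sd.output_token_ids[:]
isolated.seq_data[0].output_token_ids.append(42)
assert scheduler_view.seq_data[0].output_token_ids == []     # no side-effect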
vllm/worker/worker.py

 """A GPU worker class."""
 import gc
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple, Set, Optional

 import torch
 import torch.distributed

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
 from vllm.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel)
+    ensure_model_parallel_initialized)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
+from vllm.lora.request import LoRARequest


 class Worker:

@@ -33,6 +36,8 @@ class Worker:
         local_rank: int,
         rank: int,
         distributed_init_method: str,
+        lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config

@@ -41,12 +46,17 @@ class Worker:
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.lora_config = lora_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."

-        self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config, is_driver_worker)
+        self.model_runner = ModelRunner(model_config,
+                                        parallel_config,
+                                        scheduler_config,
+                                        lora_config=self.lora_config,
+                                        kv_cache_dtype=kv_cache_dtype,
+                                        is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None

@@ -71,9 +81,10 @@ class Worker:
         _check_if_gpu_supports_dtype(self.model_config.dtype)

         # Initialize the distributed environment.
-        _init_distributed_environment(self.parallel_config, self.rank,
-                                      self.distributed_init_method)
+        init_distributed_environment(self.parallel_config, self.rank,
+                                     self.distributed_init_method)
+        if not self.parallel_config.disable_custom_all_reduce:
+            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)

@@ -86,7 +97,16 @@ class Worker:
         block_size: int,
         gpu_memory_utilization: float,
         cpu_swap_space: int,
+        cache_dtype: str,
     ) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model and returns the maximum
+        number of GPU and CPU cache blocks that can be allocated.
+
+        Args:
+            block_size: The size of the cache block.
+            gpu_memory_utilization: The fraction of the total GPU memory to use.
+            cpu_swap_space: The size of the CPU swap space in bytes.
+        """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()

@@ -102,13 +122,16 @@ class Worker:
         peak_memory = total_gpu_memory - free_gpu_memory

         cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, self.model_config, self.parallel_config)
+            block_size, cache_dtype, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+        if self.model_runner.lora_manager:
+            self.model_runner.remove_all_loras()
         gc.collect()
         torch.cuda.empty_cache()
         return num_gpu_blocks, num_cpu_blocks

@@ -167,20 +190,21 @@ class Worker:
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            block_swapping_info = [
-                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
-            ]
-            broadcast_object_list([num_seq_groups] + block_swapping_info,
-                                  src=0)
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
         else:
-            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
-            # blocks_to_copy (4 elements)
-            recv_data = [None] * 4
-            broadcast_object_list(recv_data, src=0)
-            num_seq_groups = recv_data[0]
-            block_swapping_info = recv_data[1:]
-
-        self.cache_swap(*block_swapping_info)
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]
+
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0:

@@ -190,8 +214,17 @@ class Worker:
                                                  self.gpu_cache)
         return output

+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        return self.model_runner.remove_lora(lora_id)
+
+    def list_loras(self) -> Set[int]:
+        return self.model_runner.list_loras()
+

-def _init_distributed_environment(
+def init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
     distributed_init_method: Optional[str] = None,

@@ -218,8 +251,8 @@ def _init_distributed_environment(
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
-    initialize_model_parallel(parallel_config.tensor_parallel_size,
-                              parallel_config.pipeline_parallel_size)
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)


 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):

@@ -231,4 +264,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
         raise ValueError(
             "Bfloat16 is only supported on GPUs with compute capability "
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-            f"{compute_capability[0]}.{compute_capability[1]}.")
+            f"{compute_capability[0]}.{compute_capability[1]}. "
+            "You can use float16 instead by explicitly setting the"
+            "`dtype` flag in CLI, for example: --dtype=half.")
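
Both worker.py and model_runner.py replace the shape-then-tensor broadcast_object_list/broadcast dance with a single broadcast_tensor_dict call. A rough sketch of the driver/follower control-plane pattern, assuming torch.distributed and vLLM's model-parallel groups have already been initialized (for example via init_distributed_environment above); the function name sync_swap_plan is made up for illustration.

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def sync_swap_plan(is_driver_worker: bool, num_seq_groups=None,
                   blocks_to_swap_in=None, blocks_to_swap_out=None,
                   blocks_to_copy=None):
    """Driver packs the control data once; followers receive the same dict."""
    if is_driver_worker:
        data = {
            "num_seq_groups": num_seq_groups,
            "blocks_to_swap_in": blocks_to_swap_in,
            "blocks_to_swap_out": blocks_to_swap_out,
            "blocks_to_copy": blocks_to_copy,
        }
        broadcast_tensor_dict(data, src=0)
    else:
        data = broadcast_tensor_dict(src=0)
    return data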