Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
925f3332
Unverified
Commit
925f3332
authored
Mar 24, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 25, 2024
Browse files
[Core] Refactor Attention Take 2 (#3462)
parent
b0dfa91d
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
126 deletions
+79
-126
vllm/sequence.py
vllm/sequence.py
+1
-1
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+31
-84
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+44
-41
vllm/worker/worker.py
vllm/worker/worker.py
+3
-0
No files found.
vllm/sequence.py
View file @
925f3332
...
@@ -431,7 +431,7 @@ class SequenceGroup:
...
@@ -431,7 +431,7 @@ class SequenceGroup:
class
SequenceGroupMetadata
:
class
SequenceGroupMetadata
:
"""Metadata for a sequence group. Used to create `
Input
Metadata`.
"""Metadata for a sequence group. Used to create `
Attention
Metadata`.
Args:
Args:
request_id: The ID of the request.
request_id: The ID of the request.
...
...
vllm/worker/cache_engine.py
View file @
925f3332
"""CacheEngine class for managing the KV cache."""
"""CacheEngine class for managing the KV cache."""
from
typing
import
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
import
torch
import
torch
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_pin_memory_available
,
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
is_pin_memory_available
,
STR_DTYPE_TO_TORCH_DTYPE
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
CacheEngine
:
class
CacheEngine
:
"""Manages the KV cache.
"""Manages the KV cache.
...
@@ -43,95 +42,43 @@ class CacheEngine:
...
@@ -43,95 +42,43 @@ class CacheEngine:
else
:
else
:
self
.
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
self
.
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_config
.
cache_dtype
]
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
model_config
.
dtype
)
# Initialize the cache.
# Initialize the cache.
self
.
gpu_cache
=
self
.
allocate_gpu_cache
()
self
.
gpu_cache
=
self
.
_allocate_kv_cache
(
self
.
num_gpu_blocks
,
"cuda"
)
self
.
cpu_cache
=
self
.
allocate_cpu_cache
()
self
.
cpu_cache
=
self
.
_allocate_kv_cache
(
self
.
num_cpu_blocks
,
"cpu"
)
def
get_key_block_shape
(
self
)
->
Tuple
[
int
,
int
,
int
,
int
]:
element_size
=
torch
.
tensor
([],
dtype
=
self
.
dtype
).
element_size
()
x
=
16
//
element_size
return
(
self
.
num_heads
,
self
.
head_size
//
x
,
self
.
block_size
,
x
,
)
def
get_value_block_shape
(
self
)
->
Tuple
[
int
,
int
,
int
]:
return
(
self
.
num_heads
,
self
.
head_size
,
self
.
block_size
,
)
def
allocate_gpu_cache
(
self
)
->
List
[
KVCache
]:
gpu_cache
:
List
[
KVCache
]
=
[]
key_block_shape
=
self
.
get_key_block_shape
()
value_block_shape
=
self
.
get_value_block_shape
()
for
_
in
range
(
self
.
num_layers
):
key_blocks
=
torch
.
empty
(
size
=
(
self
.
num_gpu_blocks
,
*
key_block_shape
),
dtype
=
self
.
dtype
,
device
=
"cuda"
,
)
value_blocks
=
torch
.
empty
(
size
=
(
self
.
num_gpu_blocks
,
*
value_block_shape
),
dtype
=
self
.
dtype
,
device
=
"cuda"
,
)
gpu_cache
.
append
((
key_blocks
,
value_blocks
))
return
gpu_cache
def
allocate_cpu_cache
(
self
)
->
List
[
KVCache
]:
cpu_cache
:
List
[
KVCache
]
=
[]
key_block_shape
=
self
.
get_key_block_shape
()
value_block_shape
=
self
.
get_value_block_shape
()
pin_memory
=
is_pin_memory_available
()
for
_
in
range
(
self
.
num_layers
):
key_blocks
=
torch
.
empty
(
size
=
(
self
.
num_cpu_blocks
,
*
key_block_shape
),
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
"cpu"
,
)
value_blocks
=
torch
.
empty
(
size
=
(
self
.
num_cpu_blocks
,
*
value_block_shape
),
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
"cpu"
,
)
cpu_cache
.
append
((
key_blocks
,
value_blocks
))
return
cpu_cache
def
_swap
(
self
,
src
:
List
[
KVCache
],
dst
:
List
[
KVCache
],
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
from
vllm._C
import
cache_ops
for
i
in
range
(
self
.
num_layers
):
def
_allocate_kv_cache
(
src_key_cache
,
src_value_cache
=
src
[
i
]
self
,
dst_key_cache
,
dst_value_cache
=
dst
[
i
]
num_blocks
:
int
,
# Copy the key blocks.
device
:
str
,
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
)
->
List
[
torch
.
Tensor
]:
# Copy the value blocks.
"""Allocates KV cache on the specified device."""
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
kv_cache_shape
=
self
.
attn_backend
.
get_kv_cache_shape
(
num_blocks
,
self
.
block_size
,
self
.
num_heads
,
self
.
head_size
)
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
self
.
num_layers
):
kv_cache
.
append
(
torch
.
empty
(
kv_cache_shape
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
device
=
device
))
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
self
.
_swap
(
self
.
cpu_cache
,
self
.
gpu_cache
,
src_to_dst
)
for
i
in
range
(
self
.
num_layers
):
self
.
attn_backend
.
swap_blocks
(
self
.
cpu_cache
[
i
],
self
.
gpu_cache
[
i
],
src_to_dst
)
def
swap_out
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
def
swap_out
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
self
.
_swap
(
self
.
gpu_cache
,
self
.
cpu_cache
,
src_to_dst
)
for
i
in
range
(
self
.
num_layers
):
self
.
attn_backend
.
swap_blocks
(
self
.
gpu_cache
[
i
],
self
.
cpu_cache
[
i
],
src_to_dst
)
def
copy
(
self
,
src_to_dsts
:
Dict
[
int
,
List
[
int
]])
->
None
:
def
copy
(
self
,
src_to_dsts
:
Dict
[
int
,
List
[
int
]])
->
None
:
from
vllm._C
import
cache_ops
self
.
attn_backend
.
copy_blocks
(
self
.
gpu_cache
,
src_to_dsts
)
key_caches
=
[
key_cache
for
key_cache
,
_
in
self
.
gpu_cache
]
value_caches
=
[
value_cache
for
_
,
value_cache
in
self
.
gpu_cache
]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
cache_ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dsts
)
@
staticmethod
@
staticmethod
def
get_cache_block_size
(
def
get_cache_block_size
(
...
...
vllm/worker/model_runner.py
View file @
925f3332
...
@@ -6,10 +6,11 @@ import numpy as np
...
@@ -6,10 +6,11 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
LoRAConfig
,
ParallelConfig
,
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
LoRAConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
InputMetadata
,
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.parallel_utils
import
cupy_utils
from
vllm.model_executor.parallel_utils
import
cupy_utils
from
vllm.model_executor.parallel_utils.communication_op
import
(
from
vllm.model_executor.parallel_utils.communication_op
import
(
...
@@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
...
@@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
_PAD_SLOT_ID
=
-
1
_PAD_SLOT_ID
=
-
1
LORA_WARMUP_RANK
=
8
LORA_WARMUP_RANK
=
8
_BATCH_SIZE_ALIGNMENT
=
8
_BATCH_SIZE_ALIGNMENT
=
8
...
@@ -85,6 +85,9 @@ class ModelRunner:
...
@@ -85,6 +85,9 @@ class ModelRunner:
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
dtype
if
model_config
is
not
None
else
None
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
self
.
model_config
,
self
.
model
=
get_model
(
self
.
model_config
,
...
@@ -127,8 +130,8 @@ class ModelRunner:
...
@@ -127,8 +130,8 @@ class ModelRunner:
def
_prepare_prompt
(
def
_prepare_prompt
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
,
List
[
int
]
,
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
]]:
List
[
int
],
List
[
int
],
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
...
@@ -216,7 +219,7 @@ class ModelRunner:
...
@@ -216,7 +219,7 @@ class ModelRunner:
slot_mapping
.
append
(
slot
)
slot_mapping
.
append
(
slot
)
max_subquery_len
=
max
(
subquery_lens
)
max_subquery_len
=
max
(
subquery_lens
)
max_
seq
_len
=
max
(
prompt_lens
)
max_
prompt
_len
=
max
(
prompt_lens
)
num_prompt_tokens
=
len
(
input_tokens
)
num_prompt_tokens
=
len
(
input_tokens
)
assert
max_subquery_len
>
0
assert
max_subquery_len
>
0
...
@@ -270,7 +273,7 @@ class ModelRunner:
...
@@ -270,7 +273,7 @@ class ModelRunner:
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
out
=
seq_start_loc
[
1
:])
input
_metadata
=
InputM
etadata
(
attn
_metadata
=
self
.
attn_backend
.
make_m
etadata
(
is_prompt
=
True
,
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
prompt_lens
,
prompt_lens
=
prompt_lens
,
...
@@ -279,7 +282,7 @@ class ModelRunner:
...
@@ -279,7 +282,7 @@ class ModelRunner:
num_generation_tokens
=
0
,
num_generation_tokens
=
0
,
max_subquery_len
=
max_subquery_len
,
max_subquery_len
=
max_subquery_len
,
max_context_len
=
None
,
max_context_len
=
None
,
max_
seq
_len
=
max_
seq
_len
,
max_
prompt
_len
=
max_
prompt
_len
,
subquery_start_loc
=
subquery_start_loc
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens
=
context_lens_tensor
,
context_lens
=
context_lens_tensor
,
...
@@ -287,15 +290,15 @@ class ModelRunner:
...
@@ -287,15 +290,15 @@ class ModelRunner:
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
input
_metadata
,
prompt_lens
,
return
(
input_tokens
,
input_positions
,
attn
_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
lora_requests
)
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
,
List
[
int
]
,
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
Set
[
LoRARequest
]]:
List
[
int
],
Set
[
LoRARequest
]]:
assert
len
(
seq_group_metadata_list
)
>
0
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
...
@@ -401,7 +404,7 @@ class ModelRunner:
...
@@ -401,7 +404,7 @@ class ModelRunner:
device
=
self
.
device
,
device
=
self
.
device
,
)
)
input
_metadata
=
InputM
etadata
(
attn
_metadata
=
self
.
attn_backend
.
make_m
etadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
prompt_lens
=
None
,
prompt_lens
=
None
,
...
@@ -410,7 +413,7 @@ class ModelRunner:
...
@@ -410,7 +413,7 @@ class ModelRunner:
num_generation_tokens
=
len
(
input_tokens
),
num_generation_tokens
=
len
(
input_tokens
),
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
max_context_len
,
max_context_len
=
max_context_len
,
max_
seq
_len
=
None
,
max_
prompt
_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
,
context_lens
=
context_lens
,
...
@@ -418,7 +421,7 @@ class ModelRunner:
...
@@ -418,7 +421,7 @@ class ModelRunner:
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
)
)
return
(
input_tokens
,
input_positions
,
input
_metadata
,
return
(
input_tokens
,
input_positions
,
attn
_metadata
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
def
_prepare_sample
(
def
_prepare_sample
(
...
@@ -522,7 +525,7 @@ class ModelRunner:
...
@@ -522,7 +525,7 @@ class ModelRunner:
def
prepare_input_tensors
(
def
prepare_input_tensors
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Input
Metadata
,
SamplingMetadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Attention
Metadata
,
SamplingMetadata
,
Set
[
int
],
LoRAMapping
]:
Set
[
int
],
LoRAMapping
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# NOTE: We assume that all sequences in the group are all prompts or
...
@@ -530,11 +533,11 @@ class ModelRunner:
...
@@ -530,11 +533,11 @@ class ModelRunner:
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
# Prepare input tensors.
if
is_prompt
:
if
is_prompt
:
(
input_tokens
,
input_positions
,
input
_metadata
,
prompt_lens
,
(
input_tokens
,
input_positions
,
attn
_metadata
,
prompt_lens
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
subquery_lens
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
lora_requests
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
else
:
(
input_tokens
,
input_positions
,
input
_metadata
,
(
input_tokens
,
input_positions
,
attn
_metadata
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_index_mapping
,
lora_prompt_mapping
,
lora_requests
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
lora_requests
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
prompt_lens
=
[]
prompt_lens
=
[]
...
@@ -560,7 +563,7 @@ class ModelRunner:
...
@@ -560,7 +563,7 @@ class ModelRunner:
"lora_requests"
:
lora_requests
,
"lora_requests"
:
lora_requests
,
"lora_mapping"
:
lora_mapping
,
"lora_mapping"
:
lora_mapping
,
}
}
metadata_dict
.
update
(
input
_metadata
.
asdict_zerocopy
())
metadata_dict
.
update
(
attn
_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
...
@@ -570,7 +573,7 @@ class ModelRunner:
...
@@ -570,7 +573,7 @@ class ModelRunner:
"selected_token_indices"
)
"selected_token_indices"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_mapping
=
metadata_dict
.
pop
(
"lora_mapping"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
lora_requests
=
metadata_dict
.
pop
(
"lora_requests"
)
input
_metadata
=
InputM
etadata
(
**
metadata_dict
)
attn
_metadata
=
self
.
attn_backend
.
make_m
etadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
seq_groups
=
None
,
seq_data
=
None
,
seq_data
=
None
,
...
@@ -581,16 +584,16 @@ class ModelRunner:
...
@@ -581,16 +584,16 @@ class ModelRunner:
perform_sampling
=
False
,
perform_sampling
=
False
,
)
)
return
(
input_tokens
,
input_positions
,
input
_metadata
,
return
(
input_tokens
,
input_positions
,
attn
_metadata
,
sampling_metadata
,
lora_requests
,
lora_mapping
)
sampling_metadata
,
lora_requests
,
lora_mapping
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
(
input_tokens
,
input_positions
,
input
_metadata
,
sampling_metadata
,
(
input_tokens
,
input_positions
,
attn
_metadata
,
sampling_metadata
,
lora_requests
,
lora_requests
,
lora_mapping
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
lora_mapping
)
=
self
.
prepare_input_tensors
(
seq_group_metadata_list
)
...
@@ -598,7 +601,7 @@ class ModelRunner:
...
@@ -598,7 +601,7 @@ class ModelRunner:
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
self
.
set_active_loras
(
lora_requests
,
lora_mapping
)
# Execute the model.
# Execute the model.
if
input
_metadata
.
use_cuda_graph
:
if
attn
_metadata
.
use_cuda_graph
:
graph_batch_size
=
input_tokens
.
shape
[
0
]
graph_batch_size
=
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
model_executable
=
self
.
graph_runners
[
graph_batch_size
]
else
:
else
:
...
@@ -607,7 +610,7 @@ class ModelRunner:
...
@@ -607,7 +610,7 @@ class ModelRunner:
input_ids
=
input_tokens
,
input_ids
=
input_tokens
,
positions
=
input_positions
,
positions
=
input_positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
# Compute the logits.
# Compute the logits.
...
@@ -673,7 +676,7 @@ class ModelRunner:
...
@@ -673,7 +676,7 @@ class ModelRunner:
# Run the model with the dummy inputs.
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
(
None
,
None
)
]
*
num_layers
kv_caches
=
[
None
]
*
num_layers
self
.
execute_model
(
seqs
,
kv_caches
)
self
.
execute_model
(
seqs
,
kv_caches
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
return
return
...
@@ -705,7 +708,7 @@ class ModelRunner:
...
@@ -705,7 +708,7 @@ class ModelRunner:
return
self
.
lora_manager
.
list_loras
()
return
self
.
lora_manager
.
list_loras
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
capture_model
(
self
,
kv_caches
:
List
[
KVCache
])
->
None
:
def
capture_model
(
self
,
kv_caches
:
List
[
torch
.
Tensor
])
->
None
:
"""Cuda graph capture a model.
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
Note that CUDA graph's performance gain is negligible if number
...
@@ -759,8 +762,8 @@ class ModelRunner:
...
@@ -759,8 +762,8 @@ class ModelRunner:
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
# Create dummy
input
_metadata.
# Create dummy
attn
_metadata.
input
_metadata
=
InputM
etadata
(
attn
_metadata
=
self
.
attn_backend
.
make_m
etadata
(
is_prompt
=
False
,
is_prompt
=
False
,
slot_mapping
=
slot_mapping
[:
batch_size
],
slot_mapping
=
slot_mapping
[:
batch_size
],
prompt_lens
=
None
,
prompt_lens
=
None
,
...
@@ -769,7 +772,7 @@ class ModelRunner:
...
@@ -769,7 +772,7 @@ class ModelRunner:
num_generation_tokens
=
batch_size
,
num_generation_tokens
=
batch_size
,
max_subquery_len
=
None
,
max_subquery_len
=
None
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_context_len
=
self
.
max_context_len_to_capture
,
max_
seq
_len
=
None
,
max_
prompt
_len
=
None
,
subquery_start_loc
=
None
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens
=
context_lens
[:
batch_size
],
context_lens
=
context_lens
[:
batch_size
],
...
@@ -790,7 +793,7 @@ class ModelRunner:
...
@@ -790,7 +793,7 @@ class ModelRunner:
input_tokens
[:
batch_size
],
input_tokens
[:
batch_size
],
input_positions
[:
batch_size
],
input_positions
[:
batch_size
],
kv_caches
,
kv_caches
,
input
_metadata
,
attn
_metadata
,
memory_pool
=
self
.
graph_memory_pool
,
memory_pool
=
self
.
graph_memory_pool
,
)
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
...
@@ -826,8 +829,8 @@ class CUDAGraphRunner:
...
@@ -826,8 +829,8 @@ class CUDAGraphRunner:
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
memory_pool
,
memory_pool
,
)
->
None
:
)
->
None
:
assert
self
.
graph
is
None
assert
self
.
graph
is
None
...
@@ -839,7 +842,7 @@ class CUDAGraphRunner:
...
@@ -839,7 +842,7 @@ class CUDAGraphRunner:
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
input
_metadata
,
attn
_metadata
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -853,7 +856,7 @@ class CUDAGraphRunner:
...
@@ -853,7 +856,7 @@ class CUDAGraphRunner:
input_ids
,
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
input
_metadata
,
attn
_metadata
,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -862,9 +865,9 @@ class CUDAGraphRunner:
...
@@ -862,9 +865,9 @@ class CUDAGraphRunner:
"input_ids"
:
input_ids
,
"input_ids"
:
input_ids
,
"positions"
:
positions
,
"positions"
:
positions
,
"kv_caches"
:
kv_caches
,
"kv_caches"
:
kv_caches
,
"slot_mapping"
:
input
_metadata
.
slot_mapping
,
"slot_mapping"
:
attn
_metadata
.
slot_mapping
,
"context_lens"
:
input
_metadata
.
context_lens
,
"context_lens"
:
attn
_metadata
.
context_lens
,
"block_tables"
:
input
_metadata
.
block_tables
,
"block_tables"
:
attn
_metadata
.
block_tables
,
}
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
self
.
output_buffers
=
{
"hidden_states"
:
hidden_states
}
return
return
...
@@ -873,8 +876,8 @@ class CUDAGraphRunner:
...
@@ -873,8 +876,8 @@ class CUDAGraphRunner:
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# KV caches are fixed tensors, so we don't need to copy them.
# KV caches are fixed tensors, so we don't need to copy them.
del
kv_caches
del
kv_caches
...
@@ -882,11 +885,11 @@ class CUDAGraphRunner:
...
@@ -882,11 +885,11 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers.
# Copy the input tensors to the input buffers.
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
input_buffers
[
"input_ids"
].
copy_
(
input_ids
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"positions"
].
copy_
(
positions
,
non_blocking
=
True
)
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
input
_metadata
.
slot_mapping
,
self
.
input_buffers
[
"slot_mapping"
].
copy_
(
attn
_metadata
.
slot_mapping
,
non_blocking
=
True
)
non_blocking
=
True
)
self
.
input_buffers
[
"context_lens"
].
copy_
(
input
_metadata
.
context_lens
,
self
.
input_buffers
[
"context_lens"
].
copy_
(
attn
_metadata
.
context_lens
,
non_blocking
=
True
)
non_blocking
=
True
)
self
.
input_buffers
[
"block_tables"
].
copy_
(
input
_metadata
.
block_tables
,
self
.
input_buffers
[
"block_tables"
].
copy_
(
attn
_metadata
.
block_tables
,
non_blocking
=
True
)
non_blocking
=
True
)
# Run the graph.
# Run the graph.
self
.
graph
.
replay
()
self
.
graph
.
replay
()
...
...
vllm/worker/worker.py
View file @
925f3332
...
@@ -128,6 +128,9 @@ class Worker:
...
@@ -128,6 +128,9 @@ class Worker:
# NOTE(woosuk): Here we assume that the other processes using the same
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
# GPU did not change their memory usage during the profiling.
peak_memory
=
self
.
init_gpu_memory
-
free_gpu_memory
peak_memory
=
self
.
init_gpu_memory
-
free_gpu_memory
assert
peak_memory
>
0
,
(
"Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)
cache_block_size
=
self
.
get_cache_block_size_bytes
(
cache_block_size
=
self
.
get_cache_block_size_bytes
(
block_size
,
cache_dtype
)
block_size
,
cache_dtype
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment