Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0c15758
Unverified
Commit
e0c15758
authored
Jul 22, 2024
by
Cody Yu
Committed by
GitHub
Jul 23, 2024
Browse files
[Core] Modulize prepare input and attention metadata builder (#6596)
parent
bdf5fd13
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
409 additions
and
298 deletions
+409
-298
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+3
-17
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+22
-21
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+31
-29
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+25
-24
vllm/utils.py
vllm/utils.py
+5
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+323
-207
No files found.
vllm/attention/backends/abstract.py
View file @
e0c15758
...
...
@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
if
TYPE_CHECKING
:
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
...
...
@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@
abstractmethod
def
__init__
(
self
,
input_builder
)
->
None
:
def
__init__
(
self
,
input_builder
:
"ModelRunnerInputBuilderBase"
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
add_seq_group
(
self
,
seq_group_metadata
:
"SequenceGroupMetadata"
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise
NotImplementedError
@
abstractmethod
def
build
(
self
,
runner
:
"ModelRunnerInputBuilderBase"
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
...
...
vllm/attention/backends/flash_attn.py
View file @
e0c15758
...
...
@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class
FlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder(
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
seq_group_meta
data
.
is_prompt
block_tables
=
seq_group_meta
data
.
block_tables
is_prompt
=
inter_
data
.
is_prompt
block_tables
=
inter_
data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
...
...
@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
...
...
@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder(
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
def
build
(
self
,
seq_lens
:
List
[
int
]
,
query_lens
:
List
[
int
]
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors."""
device
=
runner
.
device
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
...
...
@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder(
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
...
vllm/attention/backends/flashinfer.py
View file @
e0c15758
...
...
@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
...
...
@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
seq_group_meta
data
.
is_prompt
block_tables
=
seq_group_meta
data
.
block_tables
computed_block_nums
=
seq_group_meta
data
.
computed_block_nums
is_prompt
=
inter_
data
.
is_prompt
block_tables
=
inter_
data
.
block_tables
computed_block_nums
=
inter_
data
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
...
...
@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
...
...
@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
self
.
block_size
,
inter_data
.
block_tables
)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
...
...
@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
def
build
(
self
,
seq_lens
:
List
[
int
]
,
query_lens
:
List
[
int
]
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
...
...
@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
...
@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
long
,
device
=
device
)
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
...
...
@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
runner
.
kv_cache_dtype
,
runner
.
model_config
.
dtype
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
...
...
@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
runner
.
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
),
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
head_dim
=
runner
.
model_config
.
get_head_size
(),
num_qo_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
),
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
self
.
runner
.
parallel_config
),
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
...
...
vllm/attention/backends/utils.py
View file @
e0c15758
...
...
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
# Error string(s) for encoder/decoder
...
...
@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID
=
-
1
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
ModelInputForGPUBuilder
)
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
def
is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
...
...
@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
,
chunked_prefill_enabled
):
is_prompt
=
seq_group_metadata
.
is_prompt
block_tables
=
seq_group_metadata
.
block_tables
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
computed_block_nums
=
inter_data
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
curr_seq_lens
,
query_lens
,
context_lens
,
curr_sliding_window_blocks
):
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
...
...
@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
...
...
@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
seq_group_metadata
.
block_tables
)
self
.
block_size
,
inter_data
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
...
...
@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
...
vllm/utils.py
View file @
e0c15758
...
...
@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return
dict
(
merged_dict
)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Concatenate the items of every inner list into one new flat list.

    Items keep their original order: every item of the first inner list,
    then every item of the second, and so on. Only one level of nesting is
    removed, and the input lists are left unmodified.
    """
    flattened = []
    for inner in lists:
        # extend() appends the inner list's items in order, matching the
        # nested-comprehension traversal.
        flattened.extend(inner)
    return flattened
def
init_cached_hf_modules
()
->
None
:
"""
Lazy initialization of the Hugging Face modules.
...
...
vllm/worker/model_runner.py
View file @
e0c15758
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment