Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e0c15758
Unverified
Commit
e0c15758
authored
Jul 22, 2024
by
Cody Yu
Committed by
GitHub
Jul 23, 2024
Browse files
[Core] Modulize prepare input and attention metadata builder (#6596)
parent
bdf5fd13
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
409 additions
and
298 deletions
+409
-298
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+3
-17
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+22
-21
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+31
-29
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+25
-24
vllm/utils.py
vllm/utils.py
+5
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+323
-207
No files found.
vllm/attention/backends/abstract.py
View file @
e0c15758
...
@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
...
@@ -7,7 +7,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
import
torch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
from
vllm.worker.model_runner_base
import
ModelRunnerInputBuilderBase
...
@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
...
@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
"""Abstract class for attention metadata builders."""
@
abstractmethod
@
abstractmethod
def
__init__
(
self
,
input_builder
)
->
None
:
def
__init__
(
self
,
input_builder
:
"ModelRunnerInputBuilderBase"
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
add_seq_group
(
self
,
seq_group_metadata
:
"SequenceGroupMetadata"
,
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise
NotImplementedError
@
abstractmethod
def
build
(
self
,
runner
:
"ModelRunnerInputBuilderBase"
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
T
:
"""Build attention metadata with on-device tensors."""
"""Build attention metadata with on-device tensors."""
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/flash_attn.py
View file @
e0c15758
...
@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -13,12 +13,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
ModelInputForGPUBuilder
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder(
...
@@ -212,30 +210,30 @@ class FlashAttentionMetadataBuilder(
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
def
_add_seq_group
(
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
chunked_prefill_enabled
:
bool
):
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
1. context length.
2. block table.
2. block table.
3. slot mapping.
3. slot mapping.
"""
"""
is_prompt
=
seq_group_meta
data
.
is_prompt
is_prompt
=
inter_
data
.
is_prompt
block_tables
=
seq_group_meta
data
.
block_tables
block_tables
=
inter_
data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
curr_seq_lens
,
query_lens
,
context_lens
,
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
curr_sliding_window_blocks
):
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
...
@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -254,7 +252,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
...
@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder(
...
@@ -270,16 +268,19 @@ class FlashAttentionMetadataBuilder(
self
.
use_v2_block_manager
)
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
self
.
block_size
,
inter_data
.
block_tables
)
seq_group_metadata
.
block_tables
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
def
build
(
self
,
seq_lens
:
List
[
int
]
,
query_lens
:
List
[
int
]
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors."""
"""Build attention metadata with on-device tensors."""
device
=
runner
.
device
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -300,7 +301,7 @@ class FlashAttentionMetadataBuilder(
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
...
vllm/attention/backends/flashinfer.py
View file @
e0c15758
...
@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -21,12 +21,10 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
from
vllm.utils
import
get_kv_cache_torch_dtype
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
ModelInputForGPUBuilder
)
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -216,6 +214,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
self
.
use_v2_block_manager
=
(
...
@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -238,26 +239,24 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# paged_kv_last_page_len is the length of the last page of each request
# paged_kv_last_page_len is the length of the last page of each request
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
self
.
paged_kv_last_page_len
:
List
[
int
]
=
[]
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
def
_add_seq_group
(
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
chunked_prefill_enabled
:
bool
):
context_lens
:
List
[
int
],
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
:
bool
,
chunked_prefill_enabled
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
1. context length.
2. block table.
2. block table.
3. slot mapping.
3. slot mapping.
"""
"""
is_prompt
=
seq_group_meta
data
.
is_prompt
is_prompt
=
inter_
data
.
is_prompt
block_tables
=
seq_group_meta
data
.
block_tables
block_tables
=
inter_
data
.
block_tables
computed_block_nums
=
seq_group_meta
data
.
computed_block_nums
computed_block_nums
=
inter_
data
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
curr_seq_lens
,
query_lens
,
context_lens
,
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
curr_sliding_window_blocks
):
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
...
@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -275,7 +274,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
and
block_tables
is
not
None
):
...
@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -290,8 +289,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
use_v2_block_manager
)
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
self
.
block_size
,
inter_data
.
block_tables
)
seq_group_metadata
.
block_tables
)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# and paged_kv_last_page_len for profile run because we will
...
@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -317,9 +315,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
last_page_len
=
self
.
block_size
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
self
.
paged_kv_last_page_len
.
append
(
last_page_len
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
,
query_lens
,
def
build
(
self
,
seq_lens
:
List
[
int
]
,
query_lens
:
List
[
int
]
,
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
...
@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -333,7 +335,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -377,7 +379,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
"attn_logit_softcapping"
,
None
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
...
@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -394,8 +396,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_len_tensor
=
None
paged_kv_last_page_len_tensor
=
None
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
runner
.
kv_cache_dtype
,
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
runner
.
model_config
.
dtype
)
self
.
runner
.
kv_cache_dtype
,
self
.
runner
.
model_config
.
dtype
)
return
FlashInferMetadata
(
return
FlashInferMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -406,11 +408,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
paged_kv_last_page_len
=
paged_kv_last_page_len_tensor
,
num_qo_heads
=
runner
.
model_config
.
get_num_attention_heads
(
num_qo_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
),
self
.
runner
.
parallel_config
),
num_kv_heads
=
runner
.
model_config
.
get_num_kv_heads
(
num_kv_heads
=
self
.
runner
.
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
),
self
.
runner
.
parallel_config
),
head_dim
=
runner
.
model_config
.
get_head_size
(),
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
page_size
=
self
.
block_size
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
...
...
vllm/attention/backends/utils.py
View file @
e0c15758
...
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
...
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import
torch
import
torch
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.attention
import
AttentionMetadata
,
AttentionMetadataBuilder
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
make_tensor_with_pad
# Error string(s) for encoder/decoder
# Error string(s) for encoder/decoder
...
@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
...
@@ -15,8 +14,7 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
PAD_SLOT_ID
=
-
1
PAD_SLOT_ID
=
-
1
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
GPUModelRunnerBase
,
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
ModelInputForGPUBuilder
)
def
is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
def
is_block_tables_empty
(
block_tables
:
Union
[
None
,
Dict
]):
...
@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -95,26 +93,27 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
block_size
=
input_builder
.
block_size
self
.
use_v2_block_manager
=
(
self
.
use_v2_block_manager
=
(
input_builder
.
scheduler_config
.
use_v2_block_manager
)
input_builder
.
scheduler_config
.
use_v2_block_manager
)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
def
_add_seq_group
(
token_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
curr_seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
chunked_prefill_enabled
:
bool
):
context_lens
:
List
[
int
],
is_prompt
=
inter_data
.
is_prompt
curr_sliding_window_blocks
:
List
[
int
],
prefix_cache_hit
,
block_tables
=
inter_data
.
block_tables
chunked_prefill_enabled
):
computed_block_nums
=
inter_data
.
computed_block_nums
is_prompt
=
seq_group_metadata
.
is_prompt
block_tables
=
seq_group_metadata
.
block_tables
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
curr_sliding_window_block
)
in
zip
(
seq_group_metadata
.
seq_data
.
keys
(),
token_lens
,
seq_lens
,
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
curr_seq_lens
,
query_lens
,
context_lens
,
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
curr_sliding_window_blocks
):
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
...
@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -132,7 +131,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
block_table
=
[]
if
prefix_cache_hit
:
if
inter_data
.
prefix_cache_hit
:
block_table
=
computed_block_nums
block_table
=
computed_block_nums
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
and
block_tables
is
not
None
):
...
@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -146,16 +145,18 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
use_v2_block_manager
)
self
.
use_v2_block_manager
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
self
.
block_size
,
inter_data
.
block_tables
)
seq_group_metadata
.
block_tables
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
def
build
(
self
,
runner
:
"GPUModelRunnerBase"
,
seq_lens
:
List
[
int
],
device
=
self
.
runner
.
device
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
device
=
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
runner
.
model_config
.
hf_config
,
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -176,7 +177,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
# The shape of graph_block_tables is
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# [max batch size, max context len // block size].
input_block_tables
=
runner
.
graph_block_tables
[:
batch_size
]
input_block_tables
=
self
.
runner
.
graph_block_tables
[:
batch_size
]
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
...
...
vllm/utils.py
View file @
e0c15758
...
@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
...
@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
return
dict
(
merged_dict
)
return
dict
(
merged_dict
)
def
flatten_2d_lists
(
lists
:
List
[
List
[
T
]])
->
List
[
T
]:
"""Flatten a list of lists to a single list."""
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
def
init_cached_hf_modules
()
->
None
:
def
init_cached_hf_modules
()
->
None
:
"""
"""
Lazy initialization of the Hugging Face modules.
Lazy initialization of the Hugging Face modules.
...
...
vllm/worker/model_runner.py
View file @
e0c15758
...
@@ -3,7 +3,7 @@ import gc
...
@@ -3,7 +3,7 @@ import gc
import
time
import
time
import
warnings
import
warnings
import
weakref
import
weakref
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
)
Tuple
,
Type
,
TypeVar
,
Union
)
...
@@ -49,7 +49,8 @@ from vllm.prompt_adapter.worker_manager import (
...
@@ -49,7 +49,8 @@ from vllm.prompt_adapter.worker_manager import (
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
from
vllm.sequence
import
(
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
from
vllm.utils
import
(
CudaMemoryProfiler
,
flatten_2d_lists
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
)
is_pin_memory_available
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
...
@@ -76,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
...
@@ -76,7 +77,7 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
TModelInputForGPU
=
TypeVar
(
'TModelInputForGPU'
,
bound
=
"ModelInputForGPU"
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPU
(
ModelRunnerInputBase
):
class
ModelInputForGPU
(
ModelRunnerInputBase
):
"""
"""
This base class contains metadata needed for the base model forward pass
This base class contains metadata needed for the base model forward pass
...
@@ -126,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
...
@@ -126,7 +127,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
return
cls
(
**
tensor_dict
)
return
cls
(
**
tensor_dict
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForGPUWithSamplingMetadata
(
ModelInputForGPU
):
class
ModelInputForGPUWithSamplingMetadata
(
ModelInputForGPU
):
"""
"""
Used by the ModelRunner.
Used by the ModelRunner.
...
@@ -168,12 +169,84 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
...
@@ -168,12 +169,84 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
class
ModelInputForGPUBuilder
(
ModelRunnerInputBuilderBase
[
ModelInputForGPU
]):
"""TBA"""
"""Build ModelInputForGPU from SequenceGroupMetadata."""
@
dataclass
class
InterDataForSeqGroup
:
"""Intermediate data for the current sequence group."""
# From sequence group metadata.
request_id
:
str
seq_ids
:
List
[
int
]
is_prompt
:
bool
block_tables
:
Optional
[
Dict
[
int
,
List
[
int
]]]
computed_block_nums
:
List
[
int
]
n_seqs
:
int
=
0
# Input tokens and positions.
input_tokens
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
input_positions
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
# The sequence length (may be capped to the sliding window).
seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The query length.
query_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The number of tokens that are already computed.
context_lens
:
List
[
int
]
=
field
(
default_factory
=
list
)
# The current sliding window block.
curr_sliding_window_blocks
:
List
[
int
]
=
field
(
default_factory
=
list
)
# LoRA inputs.
lora_index_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_prompt_mapping
:
List
[
List
[
int
]]
=
field
(
default_factory
=
list
)
lora_requests
:
Set
[
LoRARequest
]
=
field
(
default_factory
=
set
)
# Prompt adapter inputs.
prompt_adapter_index_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_prompt_mapping
:
List
[
int
]
=
field
(
default_factory
=
list
)
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
# Multi-modal inputs.
multi_modal_inputs
:
Optional
[
MultiModalInputs
]
=
None
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit
:
bool
=
False
def
__post_init__
(
self
):
self
.
n_seqs
=
len
(
self
.
seq_ids
)
self
.
input_tokens
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
input_positions
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
orig_seq_lens
=
[
0
]
*
self
.
n_seqs
self
.
query_lens
=
[
0
]
*
self
.
n_seqs
self
.
context_lens
=
[
0
]
*
self
.
n_seqs
self
.
curr_sliding_window_blocks
=
[
0
]
*
self
.
n_seqs
self
.
lora_index_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
self
.
lora_prompt_mapping
=
[[]
for
_
in
range
(
self
.
n_seqs
)]
def
__init__
(
self
,
def
__init__
(
self
,
runner
:
"GPUModelRunnerBase"
,
runner
:
"GPUModelRunnerBase"
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
):
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
):
super
().
__init__
()
super
().
__init__
()
# Compute functions for each sequence in a sequence group.
# WARNING: The order of the functions matters!
self
.
per_seq_compute_fns
=
[
self
.
_compute_lens
,
self
.
_compute_for_prefix_cache_hit
,
self
.
_compute_for_sliding_window
,
self
.
_compute_lora_input
,
]
# Compute functions for each sequence group.
# WARNING: The order of the functions matters!
self
.
per_seq_group_compute_fns
=
[
self
.
_compute_prompt_adapter_input
,
self
.
_compute_multi_modal_input
,
]
self
.
runner
=
runner
self
.
runner
=
runner
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
model_input_cls
=
self
.
runner
.
_model_input_cls
self
.
attn_backend
=
self
.
runner
.
attn_backend
self
.
attn_backend
=
self
.
runner
.
attn_backend
...
@@ -187,30 +260,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -187,30 +260,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
finished_requests_ids
=
finished_requests_ids
self
.
finished_requests_ids
=
finished_requests_ids
self
.
decode_only
=
True
self
.
decode_only
=
True
# Common inputs.
# Intermediate data (data in CPU before going to GPU) for
self
.
input_tokens
:
List
[
int
]
=
[]
# the current sequence group.
self
.
input_positions
:
List
[
int
]
=
[]
self
.
inter_data_list
:
List
[
self
.
seq_lens
:
List
[
int
]
=
[]
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
self
.
query_lens
:
List
[
int
]
=
[]
self
.
max_decode_seq_len
:
int
=
0
self
.
request_ids_to_seq_ids
:
Dict
[
str
,
List
[
int
]]
=
defaultdict
(
list
)
# LoRA inputs.
self
.
lora_index_mapping
:
List
[
int
]
=
[]
self
.
lora_prompt_mapping
:
List
[
int
]
=
[]
self
.
lora_requests
:
Set
[
LoRARequest
]
=
set
()
# Prompt adapter inputs.
self
.
prompt_adapter_index_mapping
:
List
[
int
]
=
[]
self
.
prompt_adapter_prompt_mapping
:
List
[
int
]
=
[]
self
.
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
# Multi-modal inputs.
self
.
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
# Attention metadata inputs.
# Attention metadata inputs.
self
.
attn_metadata_builder
=
self
.
attn_backend
.
make_metadata_builder
(
self
.
attn_metadata_builder
=
self
.
attn_backend
.
make_metadata_builder
(
self
)
weakref
.
proxy
(
self
)
)
# Engine/Model configurations.
# Engine/Model configurations.
self
.
chunked_prefill_enabled
=
(
self
.
chunked_prefill_enabled
=
(
...
@@ -222,175 +279,222 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -222,175 +279,222 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
block_aligned_sliding_window
=
\
self
.
block_aligned_sliding_window
=
\
self
.
sliding_window_blocks
*
self
.
block_size
self
.
sliding_window_blocks
*
self
.
block_size
def
_compute_len_for_sliding_window
(
self
,
seq_len
:
int
):
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
curr_sliding_window_blocks
=
0
seq_group_metadata
:
SequenceGroupMetadata
):
sliding_seq_len
=
seq_len
"""Compute context length, sequence length and tokens
for the given sequence data.
"""
seq_data
=
seq_group_metadata
.
seq_data
[
inter_data
.
seq_ids
[
seq_idx
]]
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
# TODO(sang): This is a hack to make sliding window work with
# Compute context length (the number of tokens that are
# paged attn. We can remove it if we make paged attn kernel
# already computed) and sequence length (total number of tokens).
# to properly handle sliding window attn.
seq_len
=
seq_data
.
get_len
()
if
self
.
sliding_window
is
not
None
:
if
inter_data
.
is_prompt
:
curr_sliding_window_blocks
=
self
.
sliding_window_blocks
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# Compute tokens.
if
inter_data
.
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
inter_data
.
seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
orig_seq_lens
[
seq_idx
]
=
seq_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
input_tokens
[
seq_idx
]
=
tokens
inter_data
.
input_positions
[
seq_idx
]
=
list
(
range
(
context_len
,
seq_len
))
inter_data
.
query_lens
[
seq_idx
]
=
seq_len
-
context_len
if
inter_data
.
is_prompt
else
1
def _compute_for_prefix_cache_hit(
        self, inter_data: "InterDataForSeqGroup", seq_idx: int,
        seq_group_metadata: "SequenceGroupMetadata"):
    """Detect a prefix-cache hit and trim the inputs of one sequence.

    A hit requires a non-empty ``inter_data.computed_block_nums``, no
    sliding window configured on the builder, and a prompt sequence group.
    The outcome is recorded on ``inter_data.prefix_cache_hit``. On a hit,
    the tokens and positions of sequence ``seq_idx`` lose their first
    ``len(computed_block_nums) * block_size`` entries, and the context and
    query lengths of that sequence are updated to match.

    Raises:
        RuntimeError: if a hit occurs while chunked prefill is enabled.
    """
    cached_blocks = inter_data.computed_block_nums

    # Note that prefix caching does not support sliding window.
    is_hit = (cached_blocks is not None and len(cached_blocks) > 0
              and self.sliding_window is None and inter_data.is_prompt)
    inter_data.prefix_cache_hit = is_hit
    if not is_hit:
        return

    if self.chunked_prefill_enabled:
        raise RuntimeError(
            "chunked prefill cannot be used with prefix caching now.")

    # Advance the context length past the cached blocks, then drop the
    # matching leading entries from the token and position lists.
    assert cached_blocks is not None
    num_cached_tokens = len(cached_blocks) * self.block_size
    for per_seq_values in (inter_data.input_tokens,
                           inter_data.input_positions):
        per_seq_values[seq_idx] = per_seq_values[seq_idx][num_cached_tokens:]
    inter_data.context_lens[seq_idx] = num_cached_tokens
    inter_data.query_lens[seq_idx] = (inter_data.seq_lens[seq_idx] -
                                      num_cached_tokens)
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Update seq_len and curr_sliding_window_block for the given
sequence data (only required by decoding) if sliding window is enabled.
"""
curr_sliding_window_block
=
0
sliding_seq_len
=
inter_data
.
seq_lens
[
seq_idx
]
if
not
inter_data
.
is_prompt
and
self
.
sliding_window
is
not
None
:
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle sliding window attn.
curr_sliding_window_block
=
self
.
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
suff_len
=
inter_data
.
seq_lens
[
seq_idx
]
%
self
.
block_size
sliding_seq_len
=
min
(
sliding_seq_len
=
min
(
seq_len
,
self
.
block_aligned_sliding_window
+
suff_len
)
inter_data
.
seq_lens
[
seq_idx
],
self
.
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
if
suff_len
>
0
:
curr_sliding_window_block
s
+=
1
curr_sliding_window_block
+=
1
else
:
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
sliding_seq_len
=
min
(
inter_data
.
seq_lens
[
seq_idx
],
return
curr_sliding_window_blocks
,
sliding_seq_len
self
.
sliding_window
)
inter_data
.
curr_sliding_window_blocks
[
seq_idx
]
=
curr_sliding_window_block
inter_data
.
seq_lens
[
seq_idx
]
=
sliding_seq_len
def _compute_lora_input(self, inter_data: "InterDataForSeqGroup",
                        seq_idx: int,
                        seq_group_metadata: "SequenceGroupMetadata"):
    """If LoRA is enabled, compute LoRA index and prompt mapping.

    Does nothing when ``self.enable_lora`` is false. Otherwise:

    * registers ``seq_group_metadata.lora_request`` in
      ``inter_data.lora_requests`` when the LoRA id is positive;
    * stores ``[lora_id] * query_len`` as the index mapping of sequence
      ``seq_idx``;
    * stores the prompt mapping of sequence ``seq_idx``: ``query_len``
      copies of the id when sampling params request prompt logprobs
      (``prompt_logprobs is not None``), otherwise a single copy.

    The mappings are written into the slot of ``seq_idx`` that
    ``__post_init__`` pre-allocates, rather than appended after those
    slots. This keeps ``lora_index_mapping[seq_idx]`` aligned with the
    other per-sequence fields and stops the lists from growing to
    ``2 * n_seqs`` entries. The flattened result is the same either way.
    """
    if not self.enable_lora:
        return

    lora_id = seq_group_metadata.lora_int_id
    if lora_id > 0:
        inter_data.lora_requests.add(seq_group_metadata.lora_request)
    query_len = inter_data.query_lens[seq_idx]

    sampling_params = seq_group_metadata.sampling_params
    wants_prompt_logprobs = (sampling_params
                             and sampling_params.prompt_logprobs is not None)
    prompt_mapping_len = query_len if wants_prompt_logprobs else 1

    inter_data.lora_index_mapping[seq_idx] = [lora_id] * query_len
    inter_data.lora_prompt_mapping[seq_idx] = [lora_id] * prompt_mapping_len
def _compute_prompt_adapter_input(
        self, inter_data: "InterDataForSeqGroup",
        seq_group_metadata: "SequenceGroupMetadata"):
    """Fill the prompt-adapter request and mappings of a sequence group.

    Nothing is changed unless prompt adapters are enabled on the builder,
    the group carries a positive ``prompt_adapter_id`` and the group is a
    prompt. In that case the adapter request is copied onto ``inter_data``.
    The index mapping becomes the adapter id repeated once per virtual
    token, followed by zeros for the rest of the query. The prompt mapping
    becomes the adapter id repeated ``query_len`` times when
    ``sampling_params.prompt_logprobs`` is truthy, and once otherwise.
    """
    if not self.enable_prompt_adapter:
        return

    adapter_id = seq_group_metadata.prompt_adapter_id
    if adapter_id <= 0:
        return
    if not inter_data.is_prompt:
        return

    # A prompt group is expected to hold exactly one sequence, so the
    # query length of the first (only) sequence is used.
    assert inter_data.n_seqs == 1
    query_len = inter_data.query_lens[0]
    inter_data.prompt_adapter_request = (
        seq_group_metadata.prompt_adapter_request)

    num_virtual_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
    adapter_part = [adapter_id] * num_virtual_tokens
    padding_part = [0] * (query_len - num_virtual_tokens)
    inter_data.prompt_adapter_index_mapping = adapter_part + padding_part

    sampling_params = seq_group_metadata.sampling_params
    if sampling_params and sampling_params.prompt_logprobs:
        prompt_mapping_len = query_len
    else:
        prompt_mapping_len = 1
    inter_data.prompt_adapter_prompt_mapping = (
        [adapter_id] * prompt_mapping_len)
def _compute_multi_modal_input(self, inter_data: "InterDataForSeqGroup",
                               seq_group_metadata: "SequenceGroupMetadata"):
    """Map the group's multi-modal data, if any, onto ``inter_data``.

    When ``seq_group_metadata.multi_modal_data`` is falsy nothing happens.
    Otherwise the data is passed through ``self.multi_modal_input_mapper``
    and the result is stored in ``inter_data.multi_modal_inputs``.
    """
    raw_mm_data = seq_group_metadata.multi_modal_data
    if raw_mm_data:
        inter_data.multi_modal_inputs = self.multi_modal_input_mapper(
            raw_mm_data)
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
def
add_seq_group
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
):
"""Add a sequence group to the builder."""
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
n_seqs
=
len
(
seq_ids
)
n_seqs
=
len
(
seq_ids
)
is_prompt
=
seq_group_metadata
.
is_prompt
is_prompt
=
seq_group_metadata
.
is_prompt
token_chunk_size
=
seq_group_metadata
.
token_chunk_size
if
is_prompt
:
if
is_prompt
:
assert
n_seqs
==
1
assert
n_seqs
==
1
self
.
decode_only
=
False
self
.
decode_only
=
False
# Mapping from request IDs to sequence IDs. Used for Jamba models
inter_data
=
self
.
InterDataForSeqGroup
(
# that manages the cache by itself.
request_id
=
seq_group_metadata
.
request_id
,
self
.
request_ids_to_seq_ids
[
seq_group_metadata
.
request_id
]
=
[]
seq_ids
=
seq_ids
,
# The number of input tokens in each sequence.
is_prompt
=
is_prompt
,
token_lens
:
List
[
int
]
=
[]
block_tables
=
seq_group_metadata
.
block_tables
,
# The number of tokens that are already computed.
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
)
context_lens
:
List
[
int
]
=
[]
self
.
inter_data_list
.
append
(
inter_data
)
# The current sliding window block for each sequence.
curr_sliding_window_blocks
:
List
[
int
]
=
[]
# The original sequence length (before applying sliding window)
# for each sequence.
orig_seq_lens
:
List
[
int
]
=
[]
# The sequence length (may be capped to the sliding window).
curr_seq_lens
:
List
[
int
]
=
[]
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
self
.
request_ids_to_seq_ids
[
seq_group_metadata
.
request_id
].
append
(
seq_id
)
computed_block_nums
=
seq_group_metadata
.
computed_block_nums
# Check if hit prefix cache (i.e., some blocks are already computed)
# Note that prefix caching does not support sliding window.
prefix_cache_hit
=
(
computed_block_nums
is
not
None
and
len
(
computed_block_nums
)
>
0
and
self
.
sliding_window
is
None
and
is_prompt
)
if
self
.
chunked_prefill_enabled
and
prefix_cache_hit
:
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching now."
)
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len
=
seq_data
.
get_len
()
if
is_prompt
:
context_len
=
seq_data
.
get_num_computed_tokens
()
else
:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len
=
seq_len
-
1
seq_len
=
min
(
seq_len
,
context_len
+
token_chunk_size
)
# Compute tokens.
if
is_prompt
:
tokens
=
seq_data
.
get_token_ids
()[
context_len
:
seq_len
]
else
:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens
=
[
seq_data
.
get_last_token_id
()]
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
if
is_prompt
:
curr_sliding_window_block
=
0
sliding_seq_len
=
seq_len
query_len
=
seq_len
-
context_len
else
:
curr_sliding_window_block
,
sliding_seq_len
=
(
self
.
_compute_len_for_sliding_window
(
seq_len
))
query_len
=
1
self
.
seq_lens
.
append
(
sliding_seq_len
)
if
not
is_prompt
:
self
.
max_decode_seq_len
=
max
(
self
.
max_decode_seq_len
,
sliding_seq_len
)
self
.
query_lens
.
append
(
query_len
)
self
.
input_tokens
.
extend
(
tokens
)
self
.
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
# Intermediate data of the current sequence group for
# the attention metadata.
token_lens
.
append
(
len
(
tokens
))
context_lens
.
append
(
context_len
)
curr_seq_lens
.
append
(
sliding_seq_len
)
curr_sliding_window_blocks
.
append
(
curr_sliding_window_block
)
orig_seq_lens
.
append
(
seq_len
)
# Update attention metadata. Note that input builder attributes
# (self.xxx) include all added sequences, so we need to slice
# the last n_seqs sequences.
self
.
attn_metadata_builder
.
add_seq_group
(
seq_group_metadata
,
token_lens
,
orig_seq_lens
,
curr_seq_lens
,
self
.
query_lens
[
-
n_seqs
:],
context_lens
,
curr_sliding_window_blocks
,
prefix_cache_hit
,
self
.
chunked_prefill_enabled
)
# LoRA data.
if
self
.
enable_lora
:
lora_id
=
seq_group_metadata
.
lora_int_id
for
query_len
in
self
.
query_lens
[
-
n_seqs
:]:
if
lora_id
>
0
:
self
.
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
self
.
lora_index_mapping
+=
[
lora_id
]
*
query_len
self
.
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
else
1
))
# Prompt adapter data. Note that when is_prompt=True,
# we expect only one sequence in the group.
if
self
.
enable_prompt_adapter
:
prompt_adapter_id
=
seq_group_metadata
.
prompt_adapter_id
if
prompt_adapter_id
>
0
and
is_prompt
:
query_len
=
self
.
query_lens
[
-
1
]
self
.
prompt_adapter_requests
.
add
(
seq_group_metadata
.
prompt_adapter_request
)
num_tokens
=
seq_group_metadata
.
\
prompt_adapter_num_virtual_tokens
pm
=
[
prompt_adapter_id
]
*
num_tokens
+
[
0
]
*
(
query_len
-
num_tokens
)
self
.
prompt_adapter_index_mapping
+=
pm
self
.
prompt_adapter_prompt_mapping
.
extend
(
[
prompt_adapter_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
# Multi-modal data.
for
seq_idx
in
range
(
n_seqs
):
mm_data
=
seq_group_metadata
.
multi_modal_data
for
per_seq_fn
in
self
.
per_seq_compute_fns
:
if
mm_
data
:
per_seq_fn
(
inter_data
,
seq_idx
,
seq_group_meta
data
)
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
self
.
multi_modal_inputs_list
.
append
(
mm_kwargs
)
per_seq_group_fn
(
inter_data
,
seq_group_metadata
)
def
build
(
self
)
->
ModelInputForGPU
:
def
build
(
self
)
->
ModelInputForGPU
:
if
not
self
.
input_tokens
:
"""Finalize the builder intermediate data and
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
input_tokens
)
for
inter_data
in
self
.
inter_data_list
])
if
not
input_tokens
:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return
self
.
model_input_cls
()
return
self
.
model_input_cls
()
input_positions
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
input_positions
)
for
inter_data
in
self
.
inter_data_list
])
seq_lens
=
[]
max_decode_seq_len
=
0
for
inter_data
in
self
.
inter_data_list
:
seq_lens
.
extend
(
inter_data
.
seq_lens
)
if
not
inter_data
.
is_prompt
:
max_decode_seq_len
=
max
(
max_decode_seq_len
,
max
(
inter_data
.
seq_lens
))
query_lens
=
flatten_2d_lists
(
[
inter_data
.
query_lens
for
inter_data
in
self
.
inter_data_list
])
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manage the cache by themselves.
request_ids_to_seq_ids
=
{
data
.
request_id
:
data
.
seq_ids
for
data
in
self
.
inter_data_list
}
batch_size
=
len
(
self
.
input_tokens
)
batch_size
=
len
(
input_tokens
)
use_captured_graph
=
(
use_captured_graph
=
(
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
self
.
decode_only
and
not
self
.
runner
.
model_config
.
enforce_eager
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
batch_size
<=
_BATCH_SIZES_TO_CAPTURE
[
-
1
]
and
self
.
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
and
max_decode_seq_len
<=
self
.
runner
.
max_seq_len_to_capture
)
# If cuda graph can be used, pad tensors accordingly.
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# See `capture_model` API for more details.
...
@@ -403,60 +507,84 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -403,60 +507,84 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
batch_size
=
graph_batch_size
batch_size
=
graph_batch_size
# Tokens and positions.
# Tokens and positions.
self
.
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_tokens
.
extend
([
0
]
*
cuda_graph_pad_size
)
self
.
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_positions
.
extend
([
0
]
*
cuda_graph_pad_size
)
input_tokens_tensor
=
torch
.
tensor
(
self
.
input_tokens
,
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
input_positions_tensor
=
torch
.
tensor
(
self
.
input_positions
,
input_positions_tensor
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
# Sequence and query lengths.
# Sequence and query lengths.
self
.
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
seq_lens
.
extend
([
1
]
*
cuda_graph_pad_size
)
# Attention metadata.
# Attention metadata.
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
self
.
runner
,
self
.
seq_lens
,
self
.
query_lens
,
cuda_graph_pad_size
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
batch_size
)
# LoRA data.
# LoRA data.
lora_requests
=
set
()
lora_mapping
=
None
if
self
.
enable_lora
:
if
self
.
enable_lora
:
self
.
lora_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
lora_requests
=
set
(
r
for
data
in
self
.
inter_data_list
for
r
in
data
.
lora_requests
)
lora_index_mapping
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
lora_index_mapping
)
for
inter_data
in
self
.
inter_data_list
])
lora_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
lora_prompt_mapping
=
flatten_2d_lists
([
flatten_2d_lists
(
inter_data
.
lora_prompt_mapping
)
for
inter_data
in
self
.
inter_data_list
])
lora_mapping
=
LoRAMapping
(
lora_mapping
=
LoRAMapping
(
self
.
lora_index_mapping
,
lora_index_mapping
,
self
.
lora_prompt_mapping
,
lora_prompt_mapping
,
)
)
else
:
lora_mapping
=
None
# Prompt adapter data.
# Prompt adapter data.
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
prompt_adapter_mapping
=
None
if
self
.
enable_prompt_adapter
:
if
self
.
enable_prompt_adapter
:
self
.
prompt_adapter_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
prompt_adapter_requests
=
set
(
data
.
prompt_adapter_request
for
data
in
self
.
inter_data_list
if
data
.
prompt_adapter_request
is
not
None
)
prompt_adapter_index_mapping
=
flatten_2d_lists
([
inter_data
.
prompt_adapter_index_mapping
for
inter_data
in
self
.
inter_data_list
])
prompt_adapter_index_mapping
.
extend
([
0
]
*
cuda_graph_pad_size
)
prompt_adapter_prompt_mapping
=
flatten_2d_lists
([
inter_data
.
prompt_adapter_prompt_mapping
for
inter_data
in
self
.
inter_data_list
])
prompt_adapter_mapping
=
PromptAdapterMapping
(
prompt_adapter_mapping
=
PromptAdapterMapping
(
self
.
prompt_adapter_index_mapping
,
prompt_adapter_index_mapping
,
self
.
prompt_adapter_prompt_mapping
,
prompt_adapter_prompt_mapping
,
)
)
else
:
prompt_adapter_mapping
=
None
# Multi-modal data.
# Multi-modal data.
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
=
[
self
.
multi_modal_inputs_list
,
device
=
self
.
runner
.
device
)
data
.
multi_modal_inputs
for
data
in
self
.
inter_data_list
if
data
.
multi_modal_inputs
is
not
None
]
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
,
device
=
self
.
runner
.
device
)
return
self
.
model_input_cls
(
return
self
.
model_input_cls
(
input_tokens
=
input_tokens_tensor
,
input_tokens
=
input_tokens_tensor
,
input_positions
=
input_positions_tensor
,
input_positions
=
input_positions_tensor
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
seq_lens
=
self
.
seq_lens
,
seq_lens
=
seq_lens
,
query_lens
=
self
.
query_lens
,
query_lens
=
query_lens
,
lora_mapping
=
lora_mapping
,
lora_mapping
=
lora_mapping
,
lora_requests
=
self
.
lora_requests
,
lora_requests
=
lora_requests
,
multi_modal_kwargs
=
multi_modal_kwargs
,
multi_modal_kwargs
=
multi_modal_kwargs
,
request_ids_to_seq_ids
=
self
.
request_ids_to_seq_ids
,
request_ids_to_seq_ids
=
request_ids_to_seq_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_mapping
=
prompt_adapter_mapping
,
prompt_adapter_requests
=
self
.
prompt_adapter_requests
)
prompt_adapter_requests
=
prompt_adapter_requests
)
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
class
GPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForGPU
]):
...
@@ -1393,15 +1521,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
...
@@ -1393,15 +1521,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
else
:
else
:
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
return
((
batch_size
+
_BATCH_SIZE_ALIGNMENT
-
1
)
//
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
_BATCH_SIZE_ALIGNMENT
*
_BATCH_SIZE_ALIGNMENT
)
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Tell whether ``block_tables`` carries no block table at all.

    Returns True when ``block_tables`` is None, or when it is a dict whose
    values are all None (an empty dict included); False otherwise.
    """
    return block_tables is None or (
        isinstance(block_tables, dict)
        and all(table is None for table in block_tables.values()))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment