Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6bb18fd
Unverified
Commit
f6bb18fd
authored
Mar 05, 2025
by
Lucas Wilkinson
Committed by
GitHub
Mar 05, 2025
Browse files
[BugFix] MLA + V1, illegal memory access and accuracy issues (#14253)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
71eaf896
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
334 additions
and
161 deletions
+334
-161
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+89
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+178
-127
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+33
-25
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+5
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+23
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-2
No files found.
tests/v1/worker/test_gpu_input_batch.py
View file @
f6bb18fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
inspect
from
typing
import
Optional
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
...
@@ -9,7 +10,8 @@ import torch
...
@@ -9,7 +10,8 @@ import torch
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
(
BlockTable
,
CachedRequestState
,
InputBatch
)
VOCAB_SIZE
=
1024
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
NUM_OUTPUT_TOKENS
=
20
...
@@ -20,6 +22,34 @@ CUDA_DEVICES = [
...
@@ -20,6 +22,34 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS
=
64
MAX_NUM_PROMPT_TOKENS
=
64
def
_compare_objs
(
obj1
,
obj2
):
attrs
=
inspect
.
getmembers
(
obj1
,
lambda
a
:
not
(
inspect
.
isroutine
(
a
)))
attr_names
=
set
([
a
[
0
]
for
a
in
attrs
if
not
(
a
[
0
].
startswith
(
'__'
)
and
a
[
0
].
endswith
(
'__'
))
])
for
attr_name
in
attr_names
:
a
=
getattr
(
obj1
,
attr_name
)
b
=
getattr
(
obj2
,
attr_name
)
is_same
=
False
if
isinstance
(
a
,
torch
.
Tensor
):
if
(
a
.
numel
()
==
0
or
b
.
numel
()
==
0
):
is_same
=
(
a
.
numel
()
==
0
and
b
.
numel
()
==
0
)
elif
torch
.
allclose
(
a
,
b
):
is_same
=
True
elif
isinstance
(
a
,
np
.
ndarray
):
if
np
.
allclose
(
a
,
b
):
is_same
=
True
elif
isinstance
(
a
,
(
BlockTable
,
SamplingMetadata
)):
_compare_objs
(
a
,
b
)
is_same
=
True
# if we make it here must be same
elif
a
==
b
:
is_same
=
True
assert
is_same
,
f
"Attribute
{
attr_name
}
is different"
\
f
" in
{
obj1
}
and
{
obj2
}
:
{
a
}
!=
{
b
}
"
def
_remove_requests
(
def
_remove_requests
(
input_batch
:
InputBatch
,
batch_size
:
int
,
input_batch
:
InputBatch
,
batch_size
:
int
,
reqs
:
list
[
CachedRequestState
])
->
tuple
[
set
[
str
],
list
[
int
]]:
reqs
:
list
[
CachedRequestState
])
->
tuple
[
set
[
str
],
list
[
int
]]:
...
@@ -254,3 +284,61 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
...
@@ -254,3 +284,61 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
assert
torch
.
allclose
(
assert
torch
.
allclose
(
expected_sampling_metadata
.
allowed_token_ids_mask
,
expected_sampling_metadata
.
allowed_token_ids_mask
,
sampling_metadata
.
allowed_token_ids_mask
)
sampling_metadata
.
allowed_token_ids_mask
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"swap_list"
,
[((
0
,
1
),
)])
def
test_swap_states_in_input_batch
(
device
:
str
,
batch_size
:
int
,
swap_list
:
list
):
"""
Tests the logic for swapping request states in the InputBatch.
This test adds a set of requests to the InputBatch and then applies each
pair in `swap_list` via `swap_states`. A reference InputBatch is built by
adding the same requests directly in the swapped order. After refreshing
the sampling metadata on both batches, their attributes are compared with
`_compare_objs` to ensure they match.
"""
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
ref_input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
reqs
:
list
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
# Add requests
for
req_index
in
range
(
batch_size
):
req
:
CachedRequestState
=
_construct_cached_request_state
(
req_index
)
input_batch
.
add_request
(
req
,
req_index
)
reqs
.
append
(
req
)
req_id_reqs
[
req
.
req_id
]
=
req
req_id_output_token_ids
[
req
.
req_id
]
=
req
.
output_token_ids
reordered_reqs
=
reqs
.
copy
()
for
swap_pair
in
swap_list
:
reordered_reqs
[
swap_pair
[
0
]],
reordered_reqs
[
swap_pair
[
1
]]
=
\
reordered_reqs
[
swap_pair
[
1
]],
reordered_reqs
[
swap_pair
[
0
]]
input_batch
.
swap_states
(
swap_pair
[
0
],
swap_pair
[
1
])
for
req_index
in
range
(
batch_size
):
req
=
reordered_reqs
[
req_index
]
ref_input_batch
.
add_request
(
req
,
req_index
)
input_batch
.
refresh_sampling_metadata
()
ref_input_batch
.
refresh_sampling_metadata
()
_compare_objs
(
input_batch
,
ref_input_batch
)
vllm/v1/attention/backends/flash_attn.py
View file @
f6bb18fd
...
@@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder:
...
@@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder:
self
.
runner
=
runner
self
.
runner
=
runner
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
):
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
pass
return
False
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
common_prefix_len
:
int
):
...
...
vllm/v1/attention/backends/mla/common.py
View file @
f6bb18fd
...
@@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
...
@@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
@
dataclass
@
dataclass
class
MLACommonMetadata
:
class
MLACommon
Prefill
Metadata
:
"""
Metadata for MLACommon.
"""
Prefill Specific Metadata """
NOTE: Please read the comment at the top of the file before trying to
@
dataclass
understand this class
class
ChunkedContextMetadata
:
"""
# New for MLA (compared to FlashAttention)
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens
:
torch
.
Tensor
starts
:
torch
.
Tensor
seq_tot
:
list
[
int
]
max_seq_lens
:
list
[
int
]
workspace
:
torch
.
Tensor
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions
:
torch
.
Tensor
block_table
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_query_len
:
int
chunked_context
:
Optional
[
ChunkedContextMetadata
]
=
None
@
dataclass
class
MLACommonDecodeMetadata
:
# Input positions for rotary embeddings since for MLA the rotary
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
# position embeddings are applied inside the attention backend
input_positions
:
torch
.
Tensor
input_positions
:
torch
.
Tensor
block_table
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
D
=
TypeVar
(
"D"
,
bound
=
MLACommonDecodeMetadata
)
@
dataclass
class
MLACommonMetadata
(
Generic
[
D
]):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |---------------- N iteration ---------------------|
...
@@ -295,30 +325,23 @@ class MLACommonMetadata:
...
@@ -295,30 +325,23 @@ class MLACommonMetadata:
# |-- query_len ---|
# |-- query_len ---|
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes
:
int
num_decode_tokens
:
int
num_prefills
:
int
# For logging.
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
# The dimension of the attention heads
# The dimension of the attention heads
head_dim
:
Optional
[
int
]
=
None
head_dim
:
Optional
[
int
]
=
None
# New for MLA (compared to FlashAttention)
decode
:
Optional
[
D
]
=
None
# For chunked prefill
prefill
:
Optional
[
MLACommonPrefillMetadata
]
=
None
num_decodes
:
Optional
[
int
]
=
None
num_decode_tokens
:
Optional
[
int
]
=
None
num_prefills
:
Optional
[
int
]
=
None
has_context
:
bool
=
False
context_chunk_cu_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_starts
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_seq_tot
:
Optional
[
list
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
list
[
int
]]
=
None
chunked_prefill_workspace
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
supported_head_sizes
=
MLACommonBackend
.
get_supported_head_sizes
()
supported_head_sizes
=
MLACommonBackend
.
get_supported_head_sizes
()
...
@@ -329,10 +352,10 @@ class MLACommonMetadata:
...
@@ -329,10 +352,10 @@ class MLACommonMetadata:
f
"received
{
self
.
head_dim
}
."
)
f
"received
{
self
.
head_dim
}
."
)
T
=
TypeVar
(
"
T
"
,
bound
=
MLACommonMetadata
)
M
=
TypeVar
(
"
M
"
,
bound
=
MLACommonMetadata
)
class
MLACommonMetadataBuilder
(
Generic
[
T
]):
class
MLACommonMetadataBuilder
(
Generic
[
M
]):
"""
"""
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
understand this class
understand this class
...
@@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
def
__init__
(
self
,
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
runner
:
"GPUModelRunner"
,
cls
:
Optional
[
type
[
T
]]
=
None
):
metadata_cls
:
Optional
[
type
[
M
]]
=
None
):
self
.
cls
=
cls
if
cls
is
not
None
else
MLACommonMetadata
self
.
metadata_cls
=
metadata_cls
\
if
metadata_cls
is
not
None
else
MLACommonMetadata
self
.
runner
=
runner
self
.
runner
=
runner
scheduler_config
=
runner
.
scheduler_config
scheduler_config
=
runner
.
scheduler_config
model_config
=
runner
.
model_config
model_config
=
runner
.
model_config
...
@@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
self
.
page_size
=
self
.
runner
.
block_size
self
.
page_size
=
self
.
runner
.
block_size
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
):
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
# We now want to reorder the batch so that the "decode" requests are at
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the fewest
# the front and the "prefill" requests are at the back, using the fewest
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
...
@@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
num_decodes
=
len
(
decodes
)
num_decodes
=
len
(
decodes
)
num_prefills
=
len
(
prefills
)
num_prefills
=
len
(
prefills
)
first_prefill
=
0
first_prefill
=
0
modified_batch
=
False
for
i
in
range
(
1
,
min
(
num_decodes
,
num_prefills
)
+
1
):
for
i
in
range
(
1
,
min
(
num_decodes
,
num_prefills
)
+
1
):
# If the decode is at the "back" of the batch, i, we can swap it
# If the decode is at the "back" of the batch, i, we can swap it
...
@@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
input_batch
.
swap_states
(
prefills
[
first_prefill
],
input_batch
.
swap_states
(
prefills
[
first_prefill
],
decodes
[
num_decodes
-
i
])
decodes
[
num_decodes
-
i
])
first_prefill
+=
1
first_prefill
+=
1
modified_batch
=
True
else
:
else
:
break
break
...
@@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
self
.
_num_decode_tokens
=
num_decode_tokens
self
.
_num_decode_tokens
=
num_decode_tokens
self
.
_num_prefill_tokens
=
num_prefill_tokens
self
.
_num_prefill_tokens
=
num_prefill_tokens
return
modified_batch
def
_build_decode
(
self
,
input_positions
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
):
return
MLACommonDecodeMetadata
(
input_positions
=
input_positions
,
block_table
=
block_table
,
seq_lens
=
seq_lens
,
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
)
->
T
:
common_prefix_len
:
int
)
->
M
:
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
device
,
non_blocking
=
True
)
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
device
,
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
device
,
...
@@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
...
@@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
input_positions
=
self
.
runner
.
positions_cpu
[:
num_actual_tokens
].
to
(
input_positions
=
self
.
runner
.
positions_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
device
,
non_blocking
=
True
).
long
()
context_chunk_cu_seq_lens
=
None
prefill_metadata
=
None
context_chunk_starts
=
None
if
self
.
_num_prefills
>
0
:
context_chunk_seq_tot
=
None
reqs_start
=
self
.
_num_decodes
# prefill_start
context_chunk_max_seq_lens
=
None
tokens_start
=
self
.
_num_decode_tokens
num_computed_tokens_cpu_tensor
=
\
context_lens_cpu
=
self
.
runner
.
input_batch
.
\
self
.
runner
.
input_batch
.
num_computed_tokens_cpu_tensor
[:
num_reqs
]
num_computed_tokens_cpu_tensor
[
reqs_start
:
num_reqs
]
context_lens_tensor
=
\
context_lens
=
context_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
num_computed_tokens_cpu_tensor
.
to
(
device
,
non_blocking
=
True
)
chunked_context_metadata
=
None
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
and
context_lens_tensor
[
self
.
_num_decodes
:].
max
()
>
0
:
and
context_lens
.
max
()
>
0
:
# NOTE: it is recommended you read the `Chunked Prefill` section in
# NOTE: it is recommended you read the `Chunked Prefill` section
# the comment at the top of the file before trying to understand
# in the comment at the top of the file before trying to
# the following code
# understand the following code
self
.
has_context
=
True
num_prefills_with_context
=
\
num_prefills_with_context
=
(
context_lens
>
0
).
sum
().
item
()
(
context_lens_tensor
[
self
.
_num_decodes
:]
>
0
).
sum
().
item
()
# currently we allocate an equal amount of workspace for each
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# algorithm here and allocate more workspace to prefills with
# longer context lengths
# longer context lengths
max_context_chunk
=
\
max_context_chunk
=
\
self
.
chunked_prefill_workspace_size
//
num_prefills_with_context
self
.
chunked_prefill_workspace_size
\
//
num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk
=
round_down
(
max_context_chunk
,
self
.
page_size
)
max_context_chunk
=
round_down
(
max_context_chunk
,
self
.
page_size
)
assert
max_context_chunk
>
0
assert
max_context_chunk
>
0
num_chunks
=
cdiv
(
context_lens
_tensor
.
max
(),
max_context_chunk
)
num_chunks
=
cdiv
(
context_lens
.
max
(),
max_context_chunk
)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks like
# `num_prefills_with_context = 4`, create a tensor that looks
# like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
context_
chunk_starts
=
\
chunk_starts
=
\
torch
.
arange
(
num_chunks
,
device
=
device
,
dtype
=
torch
.
int32
)
\
torch
.
arange
(
num_chunks
,
device
=
device
,
dtype
=
torch
.
int32
)
\
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
_num_prefills
)
\
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
_num_prefills
)
\
*
max_context_chunk
*
max_context_chunk
chunk_ends
=
torch
.
min
(
context_lens
_tensor
[
self
.
_num_decodes
:]
\
chunk_ends
=
torch
.
min
(
context_lens
.
unsqueeze
(
0
),
.
unsqueeze
(
0
),
context_
chunk_starts
+
max_context_chunk
)
chunk_starts
+
max_context_chunk
)
chunk_seq_lens
=
(
chunk_ends
-
context_
chunk_starts
).
clamp
(
min
=
0
)
chunk_seq_lens
=
(
chunk_ends
-
chunk_starts
).
clamp
(
min
=
0
)
_context
_chunk_cu_seq_lens
=
chunk_seq_lens
.
cumsum
(
dim
=
1
).
to
(
_chunk_cu_seq_lens
=
chunk_seq_lens
.
cumsum
(
dim
=
1
).
to
(
torch
.
int32
)
torch
.
int32
)
zero
=
torch
.
zeros
(
num_chunks
,
dtype
=
torch
.
int32
,
device
=
device
)
\
zero
=
torch
.
zeros
(
num_chunks
,
.
unsqueeze
(
-
1
)
dtype
=
torch
.
int32
,
context_chunk_cu_seq_lens
=
\
device
=
device
).
unsqueeze
(
-
1
)
torch
.
cat
([
zero
,
_context_chunk_cu_seq_lens
],
dim
=
1
)
context_chunk_max_seq_lens
=
\
chunked_context_metadata
=
\
chunk_seq_lens
.
max
(
dim
=
1
).
values
.
tolist
()
MLACommonPrefillMetadata
.
ChunkedContextMetadata
(
context_chunk_seq_tot
=
chunk_seq_lens
.
sum
(
dim
=
1
).
tolist
()
cu_seq_lens
=
torch
.
cat
(
assert
max
(
context_chunk_seq_tot
)
<=
\
[
zero
,
_chunk_cu_seq_lens
],
dim
=
1
),
starts
=
chunk_starts
,
seq_tot
=
chunk_seq_lens
.
sum
(
dim
=
1
).
tolist
(),
max_seq_lens
=
chunk_seq_lens
.
max
(
dim
=
1
).
values
.
tolist
(),
workspace
=
self
.
chunked_prefill_workspace
,
)
assert
max
(
chunked_context_metadata
.
max_seq_lens
)
<=
\
self
.
chunked_prefill_workspace_size
self
.
chunked_prefill_workspace_size
return
self
.
cls
(
prefill_metadata
=
MLACommonPrefillMetadata
(
input_positions
=
input_positions
,
input_positions
=
input_positions
[
tokens_start
:],
block_table
=
block_table
[
reqs_start
:,
...],
query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
],
max_query_len
=
seq_lens
[
reqs_start
:].
max
().
item
(),
chunked_context
=
chunked_context_metadata
,
)
decode_metadata
=
None
if
self
.
_num_decodes
>
0
:
decode_metadata
=
self
.
_build_decode
(
input_positions
=
input_positions
[:
self
.
_num_decode_tokens
],
block_table
=
block_table
[:
self
.
_num_decodes
,
...],
seq_lens
=
seq_lens
[:
self
.
_num_decodes
],
)
return
self
.
metadata_cls
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
# MLACommonMetadata Chunk prefill specific
# MLACommonMetadata Chunk prefill specific
num_decodes
=
self
.
_num_decodes
,
num_decodes
=
self
.
_num_decodes
,
num_decode_tokens
=
self
.
_num_decode_tokens
,
num_decode_tokens
=
self
.
_num_decode_tokens
,
num_prefills
=
self
.
_num_prefills
,
num_prefills
=
self
.
_num_prefills
,
context_chunk_cu_seq_lens
=
context_chunk_cu_seq_lens
,
prefill
=
prefill_metadata
,
context_chunk_starts
=
context_chunk_starts
,
decode
=
decode_metadata
,
context_chunk_seq_tot
=
context_chunk_seq_tot
,
context_chunk_max_seq_lens
=
context_chunk_max_seq_lens
,
)
)
class
MLACommonImpl
(
MLAAttentionImpl
[
T
],
Generic
[
T
]):
class
MLACommonImpl
(
MLAAttentionImpl
[
M
],
Generic
[
M
]):
"""
"""
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
understand this class
understand this class
...
@@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
):
):
assert
attn_metadata
.
num_prefills
is
not
None
assert
attn_metadata
.
prefill
is
not
None
assert
attn_metadata
.
context_chunk_seq_tot
is
not
None
prefill_metadata
=
attn_metadata
.
prefill
assert
attn_metadata
.
context_chunk_cu_seq_lens
is
not
None
assert
prefill_metadata
.
chunked_context
is
not
None
assert
attn_metadata
.
context_chunk_starts
is
not
None
assert
attn_metadata
.
context_chunk_max_seq_lens
is
not
None
output
=
None
output
=
None
iters
=
len
(
attn_metadata
.
context_chunk_seq_tot
)
iters
=
len
(
prefill_metadata
.
chunked_context
.
seq_tot
)
workspace
=
prefill_metadata
.
chunked_context
.
workspace
assert
attn_metadata
.
chunked_prefill_workspace
is
not
None
workspace
=
attn_metadata
.
chunked_prefill_workspace
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
toks
=
attn
_metadata
.
c
ontext_chunk_
seq_tot
[
i
]
toks
=
prefill
_metadata
.
c
hunked_context
.
seq_tot
[
i
]
ops
.
gather_cache
(
ops
.
gather_cache
(
src_cache
=
kv_c_and_k_pe_cache
,
src_cache
=
kv_c_and_k_pe_cache
,
dst
=
workspace
,
dst
=
workspace
,
block_table
=
attn
_metadata
.
block_table
,
block_table
=
prefill
_metadata
.
block_table
,
cu_seq_lens
=
attn
_metadata
.
c
ontext_chunk_
cu_seq_lens
[
i
],
cu_seq_lens
=
prefill
_metadata
.
c
hunked_context
.
cu_seq_lens
[
i
],
batch_size
=
attn_metadata
.
num_prefills
,
batch_size
=
attn_metadata
.
num_prefills
,
seq_starts
=
attn
_metadata
.
c
ontext_chunk_
starts
[
i
],
seq_starts
=
prefill
_metadata
.
c
hunked_context
.
starts
[
i
],
)
)
kv_c_normed
=
workspace
[:
toks
]
\
kv_c_normed
=
workspace
[:
toks
]
\
...
@@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v_padded
,
v
=
v_padded
,
cu_seqlens_q
=
attn
_metadata
.
query_start_loc
,
cu_seqlens_q
=
prefill
_metadata
.
query_start_loc
,
cu_seqlens_k
=
attn
_metadata
.
c
ontext_chunk_
cu_seq_lens
[
i
],
cu_seqlens_k
=
prefill
_metadata
.
c
hunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
attn
_metadata
.
max_query_len
,
max_seqlen_q
=
prefill
_metadata
.
max_query_len
,
max_seqlen_k
=
attn
_metadata
.
c
ontext_chunk_
max_seq_lens
[
i
],
max_seqlen_k
=
prefill
_metadata
.
c
hunked_context
.
max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
...
@@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
has_context
=
attn_metadata
.
has_context
assert
attn_metadata
.
prefill
is
not
None
has_context
=
attn_metadata
.
prefill
.
chunked_context
is
not
None
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
k_nope
,
v
=
kv_nope
\
...
@@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v_padded
,
v
=
v_padded
,
cu_seqlens_q
=
attn_metadata
.
query_start_loc
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
max_query_len
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
max_seq
_len
,
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query
_len
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
return_softmax_lse
=
has_context
,
return_softmax_lse
=
has_context
,
...
@@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q_nope
:
torch
.
Tensor
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
M
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
attn_metadata
:
M
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding)
# Restore head dim (for rotary embedding)
k_pe
=
k_pe
.
unsqueeze
(
1
)
k_pe
=
k_pe
.
unsqueeze
(
1
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
assert
attn_metadata
.
num_decodes
is
not
None
and
\
assert
attn_metadata
.
num_decodes
is
not
None
and
\
attn_metadata
.
num_prefills
is
not
None
and
\
attn_metadata
.
num_prefills
is
not
None
and
\
...
@@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_hs_or_q_c
=
hidden_states_or_q_c
[:
num_decode_tokens
]
decode_hs_or_q_c
=
hidden_states_or_q_c
[:
num_decode_tokens
]
decode_k_pe
=
k_pe
[:
num_decode_tokens
]
decode_k_pe
=
k_pe
[:
num_decode_tokens
]
decode_input_positions
=
\
attn_metadata
.
input_positions
[:
num_decode_tokens
]
prefill_hs_or_q_c
=
hidden_states_or_q_c
[
num_decode_tokens
:]
prefill_hs_or_q_c
=
hidden_states_or_q_c
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_k_pe
=
k_pe
[
num_decode_tokens
:]
prefill_input_positions
=
\
attn_metadata
.
input_positions
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
prefill_k_c_normed
=
k_c_normed
[
num_decode_tokens
:]
if
has_decode
:
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_nope
=
self
.
_q_proj_and_k_up_proj
(
decode_hs_or_q_c
)
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
decode_q_pe
=
torch
.
matmul
(
decode_hs_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode_q_pe
[...],
decode_k_pe
[...]
=
self
.
rotary_emb
(
decode
_
input_positions
,
decode_q_pe
,
decode_k_pe
)
attn_metadata
.
decode
.
input_positions
,
decode_q_pe
,
decode_k_pe
)
if
has_prefill
:
if
has_prefill
:
assert
attn_metadata
.
prefill
is
not
None
prefill_q
=
self
.
q_proj
(
prefill_hs_or_q_c
)[
0
]
\
prefill_q
=
self
.
q_proj
(
prefill_hs_or_q_c
)[
0
]
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
prefill_q_pe
=
prefill_q
[...,
self
.
qk_nope_head_dim
:]
prefill_q_pe
=
prefill_q
[...,
self
.
qk_nope_head_dim
:]
prefill_q_pe
[...],
prefill_k_pe
[...]
=
self
.
rotary_emb
(
prefill_q_pe
[...],
prefill_k_pe
[...]
=
self
.
rotary_emb
(
prefill_input_positions
,
prefill_q_pe
,
prefill_k_pe
)
attn_metadata
.
prefill
.
input_positions
,
prefill_q_pe
,
prefill_k_pe
)
# write the latent and rope to kv cache
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
f6bb18fd
...
@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
...
@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonDecodeMetadata
,
MLACommonImpl
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
MLACommonMetadataBuilder
)
...
@@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend):
...
@@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend):
@
dataclass
@
dataclass
class
FlashMLAMetadata
(
MLACommonMetadata
):
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
decode_tile_scheduler_metadata
:
Optional
[
tuple
[
torch
.
Tensor
,
tile_scheduler_metadata
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
torch
.
Tensor
]]
=
None
num_splits
:
torch
.
Tensor
decode_num_splits
:
Optional
[
torch
.
Tensor
]
=
None
@
dataclass
class
FlashMLAMetadata
(
MLACommonMetadata
[
FlashMLADecodeMetadata
]):
pass
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
def
__init__
(
self
,
runner
):
def
__init__
(
self
,
runner
):
super
().
__init__
(
runner
,
cls
=
FlashMLAMetadata
)
super
().
__init__
(
runner
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
self
.
runner
.
parallel_config
)
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
def
_build_decode
(
self
,
input_positions
:
torch
.
Tensor
,
common_prefix_len
:
int
):
block_table
:
torch
.
Tensor
,
m
=
super
().
build
(
num_reqs
,
num_actual_tokens
,
max_query_len
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
common_prefix_len
)
tile_scheduler_metadata
,
num_splits
=
\
if
m
.
num_decode_tokens
is
not
None
and
m
.
num_decode_tokens
>
0
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
get_mla_metadata
(
m
.
seq_lens
[:
m
.
num_decode_tokens
]
,
seq_lens
,
self
.
num_q_heads
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
1
,
# MQA for the decode path
)
)
return
m
return
FlashMLADecodeMetadata
(
input_positions
=
input_positions
,
block_table
=
block_table
,
seq_lens
=
seq_lens
,
tile_scheduler_metadata
=
tile_scheduler_metadata
,
num_splits
=
num_splits
,
)
class
FlashMLAImpl
(
MLACommonImpl
[
FlashMLAMetadata
]):
class
FlashMLAImpl
(
MLACommonImpl
[
FlashMLAMetadata
]):
...
@@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
attn_metadata
:
FlashMLAMetadata
,
attn_metadata
:
FlashMLAMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 FlashMLA not yet supported"
)
raise
NotImplementedError
(
"FP8 FlashMLA not yet supported"
)
...
@@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
o
,
_
=
flash_mla_with_kvcache
(
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
block_table
[:
attn_metadata
.
num_decodes
,
block_table
=
attn_metadata
.
decode
.
block_table
,
...],
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
cache_seqlens
=
attn_metadata
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
],
head_dim_v
=
self
.
kv_lora_rank
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
attn_metadata
.
tile_scheduler_metadata
=
attn_metadata
.
decode
.
decode_
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
=
attn_metadata
.
decode
_
num_splits
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
f6bb18fd
...
@@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Triton MLA not yet supported"
)
raise
NotImplementedError
(
"FP8 Triton MLA not yet supported"
)
...
@@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
# Run MQA
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
attn_metadata
.
block_table
,
attn_metadata
.
seq_lens
,
attn_metadata
.
decode
.
block_table
,
attn_logits
,
num_kv_splits
,
self
.
scale
,
PAGE_SIZE
)
attn_metadata
.
decode
.
seq_lens
,
attn_logits
,
num_kv_splits
,
self
.
scale
,
PAGE_SIZE
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/v1/worker/gpu_input_batch.py
View file @
f6bb18fd
...
@@ -383,8 +383,6 @@ class InputBatch:
...
@@ -383,8 +383,6 @@ class InputBatch:
self
.
req_id_to_index
[
old_id_i2
],
self
.
req_id_to_index
[
old_id_i1
]
self
.
req_id_to_index
[
old_id_i2
],
self
.
req_id_to_index
[
old_id_i1
]
self
.
num_tokens
[
i1
],
self
.
num_tokens
[
i2
]
=
\
self
.
num_tokens
[
i1
],
self
.
num_tokens
[
i2
]
=
\
self
.
num_tokens
[
i2
],
self
.
num_tokens
[
i1
]
self
.
num_tokens
[
i2
],
self
.
num_tokens
[
i1
]
self
.
token_ids_cpu
[
i1
,
...],
self
.
token_ids_cpu
[
i2
,
...],
=
\
self
.
token_ids_cpu
[
i2
,
...],
self
.
token_ids_cpu
[
i1
,
...]
self
.
num_tokens_no_spec
[
i1
],
self
.
num_tokens_no_spec
[
i2
]
=
\
self
.
num_tokens_no_spec
[
i1
],
self
.
num_tokens_no_spec
[
i2
]
=
\
self
.
num_tokens_no_spec
[
i2
],
self
.
num_tokens_no_spec
[
i1
]
self
.
num_tokens_no_spec
[
i2
],
self
.
num_tokens_no_spec
[
i1
]
self
.
num_prompt_tokens
[
i1
],
self
.
num_prompt_tokens
[
i2
]
=
\
self
.
num_prompt_tokens
[
i1
],
self
.
num_prompt_tokens
[
i2
]
=
\
...
@@ -406,24 +404,47 @@ class InputBatch:
...
@@ -406,24 +404,47 @@ class InputBatch:
self
.
min_p_cpu
[
i1
],
self
.
min_p_cpu
[
i2
]
=
\
self
.
min_p_cpu
[
i1
],
self
.
min_p_cpu
[
i2
]
=
\
self
.
min_p_cpu
[
i2
],
self
.
min_p_cpu
[
i1
]
self
.
min_p_cpu
[
i2
],
self
.
min_p_cpu
[
i1
]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp
=
self
.
token_ids_cpu
[
i1
,
...].
copy
()
self
.
token_ids_cpu
[
i1
,
...]
=
self
.
token_ids_cpu
[
i2
,
...]
self
.
token_ids_cpu
[
i2
,
...]
=
tmp
g1
=
self
.
generators
.
get
(
i1
)
g1
=
self
.
generators
.
get
(
i1
)
g2
=
self
.
generators
.
get
(
i2
)
g2
=
self
.
generators
.
get
(
i2
)
if
g1
is
not
None
:
if
g1
is
not
None
:
self
.
generators
[
i2
]
=
g1
self
.
generators
[
i2
]
=
g1
else
:
self
.
generators
.
pop
(
i2
,
None
)
if
g2
is
not
None
:
if
g2
is
not
None
:
self
.
generators
[
i1
]
=
g2
self
.
generators
[
i1
]
=
g2
else
:
self
.
generators
.
pop
(
i1
,
None
)
t1
=
self
.
min_tokens
.
get
(
i1
)
t1
=
self
.
min_tokens
.
get
(
i1
)
t2
=
self
.
min_tokens
.
get
(
i2
)
t2
=
self
.
min_tokens
.
get
(
i2
)
if
t1
is
not
None
:
if
t1
is
not
None
:
self
.
min_tokens
[
i2
]
=
t1
self
.
min_tokens
[
i2
]
=
t1
else
:
self
.
min_tokens
.
pop
(
i2
,
None
)
if
t2
is
not
None
:
if
t2
is
not
None
:
self
.
min_tokens
[
i1
]
=
t2
self
.
min_tokens
[
i1
]
=
t2
else
:
self
.
min_tokens
.
pop
(
i1
,
None
)
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
self
.
logit_bias
[
i1
],
self
.
logit_bias
[
i2
]
=
\
self
.
logit_bias
[
i1
],
self
.
logit_bias
[
i2
]
=
\
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
if
self
.
allowed_token_ids_mask_cpu_tensor
is
not
None
:
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
]
=
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i2
],
\
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
list
[
int
])
->
None
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f6bb18fd
...
@@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends (namely MLA) may want to separate requests
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
# memory-bound. This gives them a hook to do that.
self
.
attn_metadata_builder
.
reorder_batch
(
self
.
input_batch
,
modified_batch
=
self
.
attn_metadata_builder
.
reorder_batch
(
scheduler_output
)
self
.
input_batch
,
scheduler_output
)
if
modified_batch
:
self
.
input_batch
.
refresh_sampling_metadata
()
# OPTIMIZATION: Start copying the block table first.
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
# This way, we can overlap the copy with the following CPU operations.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment