Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58d1b2aa
Unverified
Commit
58d1b2aa
authored
Feb 27, 2025
by
Yang Chen
Committed by
GitHub
Feb 27, 2025
Browse files
[Attention] MLA support for V1 (#13789)
Signed-off-by:
Yang Chen
<
yangche@fb.com
>
parent
f1579b22
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1340 additions
and
59 deletions
+1340
-59
vllm/attention/layer.py
vllm/attention/layer.py
+24
-11
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+11
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+7
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+67
-2
vllm/v1/attention/backends/mla/__init__.py
vllm/v1/attention/backends/mla/__init__.py
+0
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1022
-0
vllm/v1/attention/backends/triton_mla.py
vllm/v1/attention/backends/triton_mla.py
+110
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+63
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+35
-41
No files found.
vllm/attention/layer.py
View file @
58d1b2aa
...
@@ -89,6 +89,7 @@ class Attention(nn.Module):
...
@@ -89,6 +89,7 @@ class Attention(nn.Module):
self
.
_k_scale_float
=
1.0
self
.
_k_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
self
.
use_mla
=
use_mla
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
...
@@ -158,6 +159,10 @@ class Attention(nn.Module):
...
@@ -158,6 +159,10 @@ class Attention(nn.Module):
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape
:
Optional
[
torch
.
Size
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
The KV cache is stored inside this class and is accessed via
The KV cache is stored inside this class and is accessed via
...
@@ -173,17 +178,25 @@ class Attention(nn.Module):
...
@@ -173,17 +178,25 @@ class Attention(nn.Module):
if
attn_metadata
.
enable_kv_scales_calculation
:
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
key
,
value
)
self
.
calc_kv_scales
(
key
,
value
)
if
self
.
use_output
:
if
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
output_shape
=
(
output_shape
hidden_size
=
query
.
size
(
-
1
)
if
output_shape
is
not
None
else
query
.
shape
)
# Reshape the query, key, and value tensors.
output
=
torch
.
empty
(
output_shape
,
# NOTE(woosuk): We do this outside the custom op to minimize the
dtype
=
query
.
dtype
,
# CPU overheads from the non-CUDA-graph regions.
device
=
query
.
device
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
hidden_size
=
output_shape
[
-
1
]
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
# We skip reshaping query, key and value tensors for the MLA
if
key
is
not
None
:
# backend since these tensors have different semantics and are
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
# processed differently.
if
value
is
not
None
:
if
not
self
.
use_mla
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
=
forward_context
.
attn_metadata
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
58d1b2aa
...
@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
mscale
=
yarn_get_mscale
(
scaling_factor
,
float
(
mscale_all_dim
))
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
self
.
scaling
=
self
.
scaling
*
mscale
*
mscale
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self
.
mla_attn
=
Attention
(
self
.
mla_attn
=
Attention
(
num_heads
=
self
.
num_local_heads
,
num_heads
=
self
.
num_local_heads
,
head_size
=
self
.
kv_lora_rank
,
head_size
=
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
scale
=
self
.
scaling
,
scale
=
self
.
scaling
,
num_kv_heads
=
1
,
num_kv_heads
=
1
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
...
@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
kv_c
,
k_pe
=
self
.
kv_a_proj_with_mqa
(
hidden_states
)[
0
].
split
(
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
kv_c_normed
=
self
.
kv_a_layernorm
(
kv_c
.
contiguous
())
return
self
.
mla_attn
(
hidden_states_or_q_c
,
kv_c_normed
,
k_pe
)
return
self
.
mla_attn
(
hidden_states_or_q_c
,
kv_c_normed
,
k_pe
,
output_shape
=
hidden_states
.
shape
)
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
...
vllm/platforms/cuda.py
View file @
58d1b2aa
...
@@ -162,8 +162,13 @@ class CudaPlatformBase(Platform):
...
@@ -162,8 +162,13 @@ class CudaPlatformBase(Platform):
kv_cache_dtype
,
block_size
,
use_v1
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
)
->
str
:
use_mla
)
->
str
:
if
use_v1
:
if
use_v1
:
logger
.
info
(
"Using Flash Attention backend on V1 engine."
)
if
use_mla
:
return
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
logger
.
info
(
"Using Triton MLA backend on V1 engine."
)
return
"vllm.v1.attention.backends.triton_mla.TritonMLABackend"
else
:
logger
.
info
(
"Using Flash Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend"
)
if
use_mla
:
if
use_mla
:
if
selected_backend
==
_Backend
.
FLASHMLA
:
if
selected_backend
==
_Backend
.
FLASHMLA
:
from
vllm.attention.backends.flashmla
import
(
from
vllm.attention.backends.flashmla
import
(
...
...
vllm/platforms/interface.py
View file @
58d1b2aa
...
@@ -35,6 +35,7 @@ class _Backend(enum.Enum):
...
@@ -35,6 +35,7 @@ class _Backend(enum.Enum):
OPENVINO
=
enum
.
auto
()
OPENVINO
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
TRITON_MLA
=
enum
.
auto
()
TRITON_MLA
=
enum
.
auto
()
TRITON_MLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA
=
enum
.
auto
()
FLASHMLA
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
58d1b2aa
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -14,6 +14,11 @@ from vllm.logger import init_logger
...
@@ -14,6 +14,11 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
if
TYPE_CHECKING
:
from
vllm.v1.core.scheduler_output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
return
FlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
num_blocks
:
int
,
num_blocks
:
int
,
...
@@ -85,6 +94,62 @@ class FlashAttentionMetadata:
...
@@ -85,6 +94,62 @@ class FlashAttentionMetadata:
num_input_tokens
:
int
=
0
# Number of tokens including padding.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
class
FlashAttentionMetadataBuilder
:
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
self
.
runner
=
runner
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
):
pass
def
build
(
self
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
max_query_len
:
int
,
common_prefix_len
:
int
):
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
self
.
runner
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
suffix_kv_lens
=
(
self
.
runner
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
runner
.
device
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
return
attn_metadata
class
FlashAttentionImpl
(
AttentionImpl
):
class
FlashAttentionImpl
(
AttentionImpl
):
def
__init__
(
def
__init__
(
...
@@ -371,4 +436,4 @@ def cascade_attention(
...
@@ -371,4 +436,4 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
)
suffix_lse
)
\ No newline at end of file
vllm/v1/attention/backends/mla/__init__.py
0 → 100644
View file @
58d1b2aa
vllm/v1/attention/backends/mla/common.py
0 → 100644
View file @
58d1b2aa
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/triton_mla.py
0 → 100644
View file @
58d1b2aa
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
import
torch
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
)
logger
=
init_logger
(
__name__
)
class
TritonMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"TRITON_MLA_VLLM_V1"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"TritonMLAImpl"
]:
return
TritonMLAImpl
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Triton MLA not yet supported"
)
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
zeros
(
B
,
self
.
num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
num_kv_splits
=
4
# TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits
=
torch
.
empty
(
(
B
,
self
.
num_heads
,
num_kv_splits
,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self
.
kv_lora_rank
+
1
,
),
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
# Add a head dim of 1
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
attn_metadata
.
block_table
,
attn_metadata
.
seq_lens
,
attn_logits
,
num_kv_splits
,
self
.
scale
,
PAGE_SIZE
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/v1/worker/gpu_input_batch.py
View file @
58d1b2aa
...
@@ -80,7 +80,14 @@ class InputBatch:
...
@@ -80,7 +80,14 @@ class InputBatch:
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_tokens_no_spec
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
num_computed_tokens_cpu
=
\
self
.
num_computed_tokens_cpu_tensor
.
numpy
()
# Block table.
# Block table.
self
.
block_table
=
BlockTable
(
self
.
block_table
=
BlockTable
(
...
@@ -356,6 +363,61 @@ class InputBatch:
...
@@ -356,6 +363,61 @@ class InputBatch:
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
self
.
allowed_token_ids_mask_cpu_tensor
[
req_index
].
fill_
(
False
)
return
req_index
return
req_index
def
swap_states
(
self
,
i1
:
int
,
i2
:
int
)
->
None
:
old_id_i1
=
self
.
_req_ids
[
i1
]
old_id_i2
=
self
.
_req_ids
[
i2
]
self
.
_req_ids
[
i1
],
self
.
_req_ids
[
i2
]
=
\
self
.
_req_ids
[
i2
],
self
.
_req_ids
[
i1
]
# noqa
self
.
req_output_token_ids
[
i1
],
self
.
req_output_token_ids
[
i2
]
=
\
self
.
req_output_token_ids
[
i2
],
self
.
req_output_token_ids
[
i1
]
assert
old_id_i1
is
not
None
and
old_id_i2
is
not
None
self
.
req_id_to_index
[
old_id_i1
],
self
.
req_id_to_index
[
old_id_i2
]
=
\
self
.
req_id_to_index
[
old_id_i2
],
self
.
req_id_to_index
[
old_id_i1
]
self
.
num_tokens
[
i1
],
self
.
num_tokens
[
i2
]
=
\
self
.
num_tokens
[
i2
],
self
.
num_tokens
[
i1
]
self
.
token_ids_cpu
[
i1
,
...],
self
.
token_ids_cpu
[
i2
,
...],
=
\
self
.
token_ids_cpu
[
i2
,
...],
self
.
token_ids_cpu
[
i1
,
...]
self
.
num_tokens_no_spec
[
i1
],
self
.
num_tokens_no_spec
[
i2
]
=
\
self
.
num_tokens_no_spec
[
i2
],
self
.
num_tokens_no_spec
[
i1
]
self
.
num_prompt_tokens
[
i1
],
self
.
num_prompt_tokens
[
i2
]
=
\
self
.
num_prompt_tokens
[
i2
],
self
.
num_prompt_tokens
[
i1
]
self
.
num_computed_tokens_cpu
[
i1
],
self
.
num_computed_tokens_cpu
[
i2
]
=
\
self
.
num_computed_tokens_cpu
[
i2
],
self
.
num_computed_tokens_cpu
[
i1
]
self
.
temperature_cpu
[
i1
],
self
.
temperature_cpu
[
i2
]
=
\
self
.
temperature_cpu
[
i2
],
self
.
temperature_cpu
[
i1
]
self
.
top_p_cpu
[
i1
],
self
.
top_p_cpu
[
i2
]
=
\
self
.
top_p_cpu
[
i2
],
self
.
top_p_cpu
[
i1
]
self
.
top_k_cpu
[
i1
],
self
.
top_k_cpu
[
i2
]
=
\
self
.
top_k_cpu
[
i2
],
self
.
top_k_cpu
[
i1
]
self
.
frequency_penalties_cpu
[
i1
],
self
.
frequency_penalties_cpu
[
i2
]
=
\
self
.
frequency_penalties_cpu
[
i2
],
self
.
frequency_penalties_cpu
[
i1
]
self
.
presence_penalties_cpu
[
i1
],
self
.
presence_penalties_cpu
[
i2
]
=
\
self
.
presence_penalties_cpu
[
i2
],
self
.
presence_penalties_cpu
[
i1
]
self
.
repetition_penalties_cpu
[
i1
],
self
.
repetition_penalties_cpu
[
i2
]
=
\
self
.
repetition_penalties_cpu
[
i2
],
self
.
repetition_penalties_cpu
[
i1
]
self
.
min_p_cpu
[
i1
],
self
.
min_p_cpu
[
i2
]
=
\
self
.
min_p_cpu
[
i2
],
self
.
min_p_cpu
[
i1
]
g1
=
self
.
generators
.
get
(
i1
)
g2
=
self
.
generators
.
get
(
i2
)
if
g1
is
not
None
:
self
.
generators
[
i2
]
=
g1
if
g2
is
not
None
:
self
.
generators
[
i1
]
=
g2
t1
=
self
.
min_tokens
.
get
(
i1
)
t2
=
self
.
min_tokens
.
get
(
i2
)
if
t1
is
not
None
:
self
.
min_tokens
[
i2
]
=
t1
if
t2
is
not
None
:
self
.
min_tokens
[
i1
]
=
t2
self
.
request_lora_mapping
[
i1
],
self
.
request_lora_mapping
[
i2
]
=
\
self
.
request_lora_mapping
[
i2
],
self
.
request_lora_mapping
[
i1
]
self
.
logit_bias
[
i1
],
self
.
logit_bias
[
i2
]
=
\
self
.
logit_bias
[
i2
],
self
.
logit_bias
[
i1
]
self
.
block_table
.
swap_row
(
i1
,
i2
)
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
num_reqs
=
self
.
num_reqs
num_reqs
=
self
.
num_reqs
if
num_reqs
==
0
:
if
num_reqs
==
0
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
58d1b2aa
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
gc
import
gc
import
time
import
time
import
weakref
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -9,7 +10,7 @@ import torch
...
@@ -9,7 +10,7 @@ import torch
import
torch.distributed
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.attention
.backends.abstract
import
AttentionType
from
vllm.attention
import
AttentionType
,
get_attn_backend
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
...
@@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType
...
@@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
FlashAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheClient
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheClient
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
...
@@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
head_size
=
model_config
.
get_head_size
()
self
.
head_size
=
model_config
.
get_head_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
hidden_size
=
model_config
.
get_hidden_size
()
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
self
.
attn_backend
is
None
:
error_msg
=
(
f
"Error with get_att_backend:
{
self
.
head_size
=
}
, "
f
"
{
self
.
dtype
=
}
,
{
self
.
kv_cache_dtype
=
}
,
{
self
.
block_size
=
}
, "
f
"
{
self
.
model_config
.
is_attention_free
=
}
, "
f
"
{
self
.
model_config
.
use_mla
=
}
"
)
logger
.
error
(
error_msg
)
raise
NotImplementedError
(
"Non-Attention backend is not supported by V1 GPUModelRunner."
)
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
))
# Multi-modal data support
# Multi-modal data support
self
.
input_registry
=
INPUT_REGISTRY
self
.
input_registry
=
INPUT_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
...
@@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
assert
num_reqs
>
0
assert
num_reqs
>
0
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
self
.
attn_metadata_builder
.
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
# OPTIMIZATION: Start copying the block table first.
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
# This way, we can overlap the copy with the following CPU operations.
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
...
@@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
seq_lens_np
[:
num_reqs
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
self
.
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
+
num_scheduled_tokens
)
num_scheduled_tokens
)
max_seq_len
=
self
.
seq_lens_np
[:
num_reqs
].
max
()
# Copy the tensors to the GPU.
# Copy the tensors to the GPU.
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
self
.
input_ids
[:
total_num_scheduled_tokens
].
copy_
(
...
@@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
positions
[:
total_num_scheduled_tokens
].
copy_
(
self
.
positions
[:
total_num_scheduled_tokens
].
copy_
(
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
self
.
positions_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
non_blocking
=
True
)
query_start_loc
=
self
.
query_start_loc_cpu
[:
num_reqs
+
1
].
to
(
self
.
device
,
non_blocking
=
True
)
seq_lens
=
self
.
seq_lens_cpu
[:
num_reqs
].
to
(
self
.
device
,
non_blocking
=
True
)
slot_mapping
=
self
.
slot_mapping_cpu
[:
total_num_scheduled_tokens
].
to
(
self
.
device
,
non_blocking
=
True
).
long
()
# Prepare for cascade attention if needed.
# Prepare for cascade attention if needed.
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
self
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
num_common_prefix_blocks
,
scheduler_output
.
num_common_prefix_blocks
,
)
)
use_cascade
=
common_prefix_len
>
0
attn_metadata
=
self
.
attn_metadata_builder
.
build
(
if
use_cascade
:
num_reqs
=
num_reqs
,
# TODO: Optimize.
cu_prefix_query_lens
=
torch
.
tensor
(
[
0
,
total_num_scheduled_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
device
=
self
.
device
)
suffix_kv_lens
=
(
self
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
device
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
total_num_scheduled_tokens
,
num_actual_tokens
=
total_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
max_query_len
=
max_num_scheduled_tokens
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
(
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
)
)
use_spec_decode
=
len
(
use_spec_decode
=
len
(
...
@@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from these partial requests, we do so for simplicity.
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
# TODO: Support prompt logprobs.
logits_indices
=
query_start_loc
[
1
:]
-
1
logits_indices
=
attn_metadata
.
query_start_loc
[
1
:]
-
1
# Hot-Swap lora model
# Hot-Swap lora model
if
self
.
lora_config
:
if
self
.
lora_config
:
...
@@ -667,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -667,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# common_prefix_len should be a multiple of the block size.
# common_prefix_len should be a multiple of the block size.
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
common_prefix_len
=
(
common_prefix_len
//
self
.
block_size
*
self
.
block_size
)
self
.
block_size
)
use_cascade
=
FlashAttentionB
ackend
.
use_cascade_attention
(
use_cascade
=
self
.
attn_b
ackend
.
use_cascade_attention
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
query_lens
=
num_scheduled_tokens
,
query_lens
=
num_scheduled_tokens
,
num_query_heads
=
self
.
num_query_heads
,
num_query_heads
=
self
.
num_query_heads
,
...
@@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
tensor_config
.
size
%
layer_spec
.
page_size_bytes
==
0
assert
tensor_config
.
size
%
layer_spec
.
page_size_bytes
==
0
num_blocks
=
tensor_config
.
size
//
layer_spec
.
page_size_bytes
num_blocks
=
tensor_config
.
size
//
layer_spec
.
page_size_bytes
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
if
isinstance
(
layer_spec
,
FullAttentionSpec
):
kv_cache_shape
=
FlashAttentionB
ackend
.
get_kv_cache_shape
(
kv_cache_shape
=
self
.
attn_b
ackend
.
get_kv_cache_shape
(
num_blocks
,
layer_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
num_blocks
,
layer_spec
.
block_size
,
layer_spec
.
num_kv_heads
,
layer_spec
.
head_size
)
layer_spec
.
head_size
)
dtype
=
layer_spec
.
dtype
dtype
=
layer_spec
.
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment