Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a78dd330
Unverified
Commit
a78dd330
authored
Nov 01, 2024
by
sroy745
Committed by
GitHub
Nov 01, 2024
Browse files
[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
parent
d522034c
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
716 additions
and
317 deletions
+716
-317
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+51
-37
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+115
-41
tests/kernels/utils.py
tests/kernels/utils.py
+74
-16
tests/models/encoder_decoder/vision_language/test_florence2.py
.../models/encoder_decoder/vision_language/test_florence2.py
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+278
-86
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+143
-16
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+26
-105
vllm/attention/selector.py
vllm/attention/selector.py
+1
-1
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+0
-2
vllm/utils.py
vllm/utils.py
+2
-2
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-10
No files found.
tests/encoder_decoder/test_e2e_correctness.py
View file @
a78dd330
...
@@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
...
@@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.attention.selector
import
(
_Backend
,
global_force_attn_backend_context_manager
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
DecoderPromptType
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
from
..models.utils
import
check_logprobs_close
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
None
]
def
vllm_to_hf_output
(
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
...
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
...
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
...
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
...
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
num_logprobs
:
int
,
num_logprobs
:
int
,
decoder_prompt_type
:
DecoderPromptType
,
decoder_prompt_type
:
DecoderPromptType
,
enforce_eager
:
bool
,
enforce_eager
:
bool
,
attn_backend
:
_Backend
,
)
->
None
:
)
->
None
:
'''
'''
End-to-End (E2E) test for the encoder-decoder framework.
End-to-End (E2E) test for the encoder-decoder framework.
...
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
...
@@ -56,43 +64,49 @@ def test_encoder_decoder_e2e(
implementations to ensure that both implementations produce consistent
implementations to ensure that both implementations produce consistent
and correct results.
and correct results.
'''
'''
test_case_prompts
=
example_encoder_decoder_prompts
[
decoder_prompt_type
]
with
global_force_attn_backend_context_manager
(
attn_backend
):
if
attn_backend
==
_Backend
.
FLASH_ATTN
:
# Flash Attention works only with bfloat16 data-type
dtype
=
'bfloat16'
test_case_prompts
=
example_encoder_decoder_prompts
[
decoder_prompt_type
]
# Configuration settings for HF baseline
# Configuration settings for HF baseline
hf_kwargs
=
{
hf_kwargs
=
{
"top_k"
:
None
,
"top_k"
:
None
,
"num_beams"
:
1
,
"num_beams"
:
1
,
"repetition_penalty"
:
1.0
,
"repetition_penalty"
:
1.0
,
"top_p"
:
1.0
,
"top_p"
:
1.0
,
"length_penalty"
:
1.0
,
"length_penalty"
:
1.0
,
"early_stopping"
:
False
,
"early_stopping"
:
False
,
"no_repeat_ngram_size"
:
None
,
"no_repeat_ngram_size"
:
None
,
"min_length"
:
0
"min_length"
:
0
}
}
with
hf_runner
(
model
,
dtype
=
dtype
,
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
hf_outputs
=
(
test_case_prompts
,
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
max_tokens
,
test_case_prompts
,
num_logprobs
,
max_tokens
,
**
hf_kwargs
,
num_logprobs
,
))
**
hf_kwargs
,
with
vllm_runner
(
model
,
dtype
=
dtype
,
))
enforce_eager
=
enforce_eager
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
enforce_eager
=
enforce_eager
)
as
vllm_model
:
test_case_prompts
,
max_tokens
,
num_logprobs
)
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
test_case_prompts
,
max_tokens
,
num_logprobs
)
hf_skip_tokens
=
(
1
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
else
0
)
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
vllm_to_hf_output
(
vllm_output
,
decoder_prompt_type
)
for
vllm_output
in
vllm_outputs
for
vllm_output
in
vllm_outputs
],
],
name_0
=
"hf"
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
name_1
=
"vllm"
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
num_outputs_0_skip_tokens
=
hf_skip_tokens
,
)
)
tests/kernels/test_encoder_decoder_attn.py
View file @
a78dd330
...
@@ -16,13 +16,13 @@ from tests.kernels.utils import *
...
@@ -16,13 +16,13 @@ from tests.kernels.utils import *
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
from
vllm.attention.selector
import
(
_Backend
,
get_attn_backend
,
global_force_attn_backend_context_manager
)
global_force_attn_backend_context_manager
)
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# List of support backends for encoder/decoder models
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
]
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
HEAD_SIZES
=
[
64
,
256
]
HEAD_SIZES
=
[
64
,
256
]
NUM_HEADS
=
[
1
,
16
]
NUM_HEADS
=
[
1
,
16
]
...
@@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
...
@@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt
.
num_heads
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
head_size
,
test_pt
.
block_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
)
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
kv_cache
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
kv_cache
)
...
@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
...
@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
attn
:
Attention
,
attn
:
Attention
,
encoder_test_params
:
PhaseTestParameters
,
encoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
'''
'''
Run encoder attention.
Run encoder attention.
...
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
...
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
query/key/value fields
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
Returns:
* Attention.forward() applied to packed {query,key,value} and
* Attention.forward() applied to packed {query,key,value} and
...
@@ -619,20 +623,31 @@ def _run_encoder_attention_test(
...
@@ -619,20 +623,31 @@ def _run_encoder_attention_test(
attn_type
=
AttentionType
.
ENCODER
attn_type
=
AttentionType
.
ENCODER
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
assert
packed_qkv
is
not
None
return
attn
.
forward
(
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
packed_qkv
.
key
,
# In the test setup the shape of the query is
packed_qkv
.
value
,
# [batch_size, seq_len, num_heads, head_size]. However
torch
.
tensor
([],
# the attention backend expect the shape to be
dtype
=
torch
.
float32
,
# [num_tokens, hidden_size]. Hence reshape the query before
device
=
packed_qkv
.
query
.
device
),
# invoking the forward method.
attn_metadata
,
# TODO - Update the way we construct the query so that it
attn_type
=
attn_type
)
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
,
attn_type
=
attn_type
)
def
_run_decoder_self_attention_test
(
def
_run_decoder_self_attention_test
(
test_rsrcs
:
TestResources
,
test_rsrcs
:
TestResources
,
decoder_test_params
:
PhaseTestParameters
,
decoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
'''
'''
Run decoder self-attention test.
Run decoder self-attention test.
...
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
...
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
query/key/value fields
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
(contains KV cache memory-mapping)
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
* Attention.forward() applied to packed_{query,key,value}, kv_cache
...
@@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
...
@@ -660,12 +677,22 @@ def _run_decoder_self_attention_test(
kv_cache
=
test_rsrcs
.
kv_cache
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
assert
packed_qkv
is
not
None
return
attn
.
forward
(
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
packed_qkv
.
key
,
# In the test setup the shape of the query is
packed_qkv
.
value
,
# [batch_size, seq_len, num_heads, head_size]. However
kv_cache
,
# the attention backend expect the shape to be
attn_metadata
,
# [num_tokens, hidden_size]. Hence reshape the query before
attn_type
=
attn_type
)
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
def
_run_encoder_decoder_cross_attention_test
(
def
_run_encoder_decoder_cross_attention_test
(
...
@@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
decoder_test_params
:
PhaseTestParameters
,
decoder_test_params
:
PhaseTestParameters
,
cross_test_params
:
Optional
[
PhaseTestParameters
],
cross_test_params
:
Optional
[
PhaseTestParameters
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
'''
'''
Run encoder/decoder cross-attention test.
Run encoder/decoder cross-attention test.
...
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
(number_of_tokens x num_heads x head_size)
(number_of_tokens x num_heads x head_size)
key/value fields
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
* Attention.forward() applied to packed_{query,key,value}, kv_cache
...
@@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv
=
cross_test_params
.
packed_qkvo
.
packed_qkv
cross_pckd_qkv
=
cross_test_params
.
packed_qkvo
.
packed_qkv
key
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
key
)
key
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
key
)
value
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
value
)
value
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
value
)
return
attn
.
forward
(
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
key
,
# In the test setup the shape of the query is
value
,
# [batch_size, seq_len, num_heads, head_size]. However
kv_cache
,
# the attention backend expect the shape to be
attn_metadata
,
# [num_tokens, hidden_size]. Hence reshape the query before
attn_type
=
attn_type
)
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
@
pytest
.
fixture
(
autouse
=
True
)
def
set_reset_environment
(
attn_backend
):
# Set the default torch datatype to bfloat16 to enable
# testing of the Flash Attention backend. Also clear the
# cached value of the backend.
default_dtype
=
torch
.
get_default_dtype
()
if
attn_backend
.
name
==
'FLASH_ATTN'
:
torch
.
set_default_dtype
(
torch
.
bfloat16
)
get_attn_backend
.
cache_clear
()
yield
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
torch
.
set_default_dtype
(
default_dtype
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
...
@@ -773,10 +828,8 @@ def test_encoder_only(
...
@@ -773,10 +828,8 @@ def test_encoder_only(
* max_dec_seq_len: max length of decoder input sequences
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
'''
# Force Attention wrapper backend
# Force Attention wrapper backend
with
global_force_attn_backend_context_manager
(
attn_backend
):
with
global_force_attn_backend_context_manager
(
attn_backend
):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
# is not part of this test
...
@@ -807,10 +860,14 @@ def test_encoder_only(
...
@@ -807,10 +860,14 @@ def test_encoder_only(
# PREFILL: encoder attention
# PREFILL: encoder attention
enc_pckd_act_out
:
torch
.
Tensor
=
(
_run_encoder_attention_test
(
enc_pckd_act_out
:
torch
.
Tensor
=
(
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
))
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
))
# - Is encoder attention result correct?
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
)
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
attn_backend
.
name
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
...
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
...
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
* max_dec_seq_len: max length of decoder input sequences
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
'''
# Force Attention wrapper backend
# Force Attention wrapper backend
with
global_force_attn_backend_context_manager
(
attn_backend
):
with
global_force_attn_backend_context_manager
(
attn_backend
):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
# is not part of this test
...
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
...
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_test_params
,
enc_test_params
,
prephase_attn_metadata
)
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is encoder attention result correct?
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
)
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
attn_backend
.
name
)
# PREFILL: decoder self-attention test
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
)
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is prefill decoder self-attention correct?
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal
(
prephase_dec_test_params
,
assert_actual_matches_ideal
(
prephase_dec_test_params
,
prephase_dec_pckd_act_out
)
prephase_dec_pckd_act_out
,
attn_backend
.
name
)
# PREFILL: encoder/decoder cross-attention test
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
prephase_dec_test_params
,
prephase_cross_test_params
,
test_rsrcs
,
prephase_attn_metadata
)
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is prefill encoder/decoder cross-attention correct?
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
prephase_cross_test_params
,
assert_actual_matches_ideal
(
prephase_cross_test_params
,
prephase_cross_pckd_act_out
)
prephase_cross_pckd_act_out
,
attn_backend
.
name
)
# DECODE: build decode-phase attention metadata
# DECODE: build decode-phase attention metadata
...
@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
...
@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
)
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
# - Is decode-phase decoder self-attention correct?
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal
(
decphase_dec_test_params
,
assert_actual_matches_ideal
(
decphase_dec_test_params
,
decphase_dec_pckd_act_out
)
decphase_dec_pckd_act_out
,
attn_backend
.
name
)
# DECODE: encoder/decoder cross-attention test
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
decphase_dec_test_params
,
None
,
decphase_attn_metadata
)
test_rsrcs
,
decphase_dec_test_params
,
None
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
# - Is decode-phase encoder/decoder cross-attention correct?
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
decphase_cross_test_params
,
assert_actual_matches_ideal
(
decphase_cross_test_params
,
decphase_cross_pckd_act_out
)
decphase_cross_pckd_act_out
,
attn_backend
.
name
)
tests/kernels/utils.py
View file @
a78dd330
...
@@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
...
@@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_
XFORMERS
_ATTN_VAL
,
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_
FLASH
_ATTN_VAL
,
make_tensor_with_pad
)
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
# For now, disable "test_aot_dispatch_dynamic" since there are some
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
# bugs related to this test in PyTorch 2.4.
...
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
...
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
if
backend_name
==
STR_XFORMERS_ATTN_VAL
:
if
backend_name
==
STR_XFORMERS_ATTN_VAL
:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from
vllm.attention.backends.xformers
import
XFormersBackend
from
vllm.attention.backends.xformers
import
XFormersBackend
return
XFormersBackend
()
return
XFormersBackend
()
elif
backend_name
==
STR_FLASH_ATTN_VAL
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
return
FlashAttentionBackend
()
raise
AssertionError
(
raise
AssertionError
(
f
"Unrecognized backend_name
{
backend_name
}
for unit test"
)
f
"Unrecognized backend_name
{
backend_name
}
for unit test"
)
def
_make_metadata_tensors
(
def
_make_metadata_tensors
(
seq_lens
:
Optional
[
List
[
int
]],
context_lens
:
Optional
[
List
[
int
]],
seq_lens
:
Optional
[
List
[
int
]],
encoder_seq_lens
:
Optional
[
List
[
int
]],
device
:
Union
[
torch
.
device
,
str
]
context_lens
:
Optional
[
List
[
int
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
List
[
int
]],
encoder_seq_lens
:
Optional
[
List
[
int
]],
torch
.
Tensor
,
Optional
[
int
]]:
device
:
Union
[
torch
.
device
,
str
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
int
]]:
'''
'''
Build scalar & tensor values required to build attention metadata structure.
Build scalar & tensor values required to build attention metadata structure.
...
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
...
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
* max_context_len: max(context_lens)
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* seq_start_loc: start idx of each sequence
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
* encoder_seq_start_loc: start idx of each encoder sequence
* max_encoder_seq_len: encoder seq_lens list, as tensor
* max_encoder_seq_len: encoder seq_lens list, as tensor
'''
'''
seq_lens_tensor
=
maybe_make_int_tensor
(
seq_lens
,
device
)
seq_lens_tensor
=
maybe_make_int_tensor
(
seq_lens
,
device
)
...
@@ -566,8 +573,26 @@ def _make_metadata_tensors(
...
@@ -566,8 +573,26 @@ def _make_metadata_tensors(
seq_start_loc
=
None
seq_start_loc
=
None
if
seq_lens_tensor
is
not
None
:
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
seq_lens_tensor
.
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
encoder_seq_start_loc
=
torch
.
zeros
(
encoder_seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
encoder_seq_lens_tensor
.
device
)
torch
.
cumsum
(
encoder_seq_lens_tensor
,
dim
=
0
,
dtype
=
encoder_seq_start_loc
.
dtype
,
out
=
encoder_seq_start_loc
[
1
:])
return
(
seq_lens_tensor
,
context_lens_tensor
,
max_context_len
,
max_seq_len
,
return
(
seq_lens_tensor
,
context_lens_tensor
,
max_context_len
,
max_seq_len
,
seq_start_loc
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
)
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
)
def
make_kv_cache
(
num_blocks
:
int
,
def
make_kv_cache
(
num_blocks
:
int
,
...
@@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
...
@@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
head_size
:
int
,
head_size
:
int
,
block_size
:
int
,
block_size
:
int
,
device
:
Union
[
torch
.
device
,
str
],
device
:
Union
[
torch
.
device
,
str
],
backend
:
str
,
default_val
:
float
=
0.0
)
->
torch
.
Tensor
:
default_val
:
float
=
0.0
)
->
torch
.
Tensor
:
'''
'''
Create a fake KV cache.
Create a fake KV cache.
...
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
...
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
Returns:
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
'''
'''
if
backend
==
'XFORMERS'
:
kv_cache
=
torch
.
rand
(
kv_cache
=
torch
.
rand
(
(
2
,
num_blocks
,
block_size
*
num_heads
*
head_size
)).
to
(
device
)
(
2
,
num_blocks
,
block_size
*
num_heads
*
head_size
)).
to
(
device
)
elif
backend
==
'FLASH_ATTN'
:
kv_cache
=
torch
.
rand
(
(
2
,
num_blocks
,
block_size
,
num_heads
,
head_size
)).
to
(
device
)
else
:
raise
ValueError
(
f
"Unknown backend value: '
{
backend
}
'. Expected 'XFORMERS' or "
f
"'FLASH_ATTN'."
)
if
default_val
is
not
None
:
if
default_val
is
not
None
:
kv_cache
[:,
:,
:]
=
default_val
kv_cache
[:,
:,
:]
=
default_val
return
kv_cache
return
kv_cache
...
@@ -858,8 +894,9 @@ def make_test_metadata(
...
@@ -858,8 +894,9 @@ def make_test_metadata(
context_lens_tensor
,
context_lens_tensor
,
_
,
_
,
_
,
_
,
_
,
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
,
max_encoder_seq_len
,
)
=
_make_metadata_tensors
(
seq_lens
,
)
=
_make_metadata_tensors
(
seq_lens
,
context_lens
,
context_lens
,
...
@@ -874,6 +911,7 @@ def make_test_metadata(
...
@@ -874,6 +911,7 @@ def make_test_metadata(
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_start_loc
=
seq_start_loc
,
max_prefill_seq_len
=
None
if
seq_lens
is
None
else
max
(
seq_lens
),
max_prefill_seq_len
=
None
if
seq_lens
is
None
else
max
(
seq_lens
),
max_decode_seq_len
=
0
,
max_decode_seq_len
=
0
,
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
...
@@ -882,6 +920,7 @@ def make_test_metadata(
...
@@ -882,6 +920,7 @@ def make_test_metadata(
num_encoder_tokens
=
num_encoder_tokens
,
num_encoder_tokens
=
num_encoder_tokens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
encoder_seq_start_loc
,
max_encoder_seq_len
=
max_encoder_seq_len
,
max_encoder_seq_len
=
max_encoder_seq_len
,
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_kv_mmap
.
slot_mapping
),
cross_kv_mmap
.
slot_mapping
),
...
@@ -904,8 +943,9 @@ def make_test_metadata(
...
@@ -904,8 +943,9 @@ def make_test_metadata(
context_lens_tensor
,
context_lens_tensor
,
_
,
_
,
_
,
_
,
_
,
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
,
max_encoder_seq_len
,
)
=
_make_metadata_tensors
(
seq_lens
,
)
=
_make_metadata_tensors
(
seq_lens
,
context_lens
,
context_lens
,
...
@@ -920,14 +960,17 @@ def make_test_metadata(
...
@@ -920,14 +960,17 @@ def make_test_metadata(
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_start_loc
=
seq_start_loc
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
max
(
seq_lens
),
max_decode_seq_len
=
max
(
seq_lens
),
max_decode_query_len
=
1
,
context_lens_tensor
=
context_lens_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
kv_mmap
.
block_tables
,
block_tables
=
kv_mmap
.
block_tables
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
num_encoder_tokens
=
num_encoder_tokens
,
num_encoder_tokens
=
num_encoder_tokens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
encoder_seq_start_loc
,
max_encoder_seq_len
=
max_encoder_seq_len
,
max_encoder_seq_len
=
max_encoder_seq_len
,
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_kv_mmap
.
slot_mapping
),
cross_kv_mmap
.
slot_mapping
),
...
@@ -936,7 +979,8 @@ def make_test_metadata(
...
@@ -936,7 +979,8 @@ def make_test_metadata(
def
assert_actual_matches_ideal
(
test_params
:
PhaseTestParameters
,
def
assert_actual_matches_ideal
(
test_params
:
PhaseTestParameters
,
output_under_test
:
torch
.
Tensor
)
->
None
:
output_under_test
:
torch
.
Tensor
,
backend
:
str
)
->
None
:
'''
'''
Assert that observed output matches the ideal output
Assert that observed output matches the ideal output
contained in the test parameters data structure.
contained in the test parameters data structure.
...
@@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
...
@@ -947,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value
* output_under_test: actually observed output value
'''
'''
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
torch
.
testing
.
assert_close
(
ideal_output
,
if
backend
==
'XFORMERS'
:
output_under_test
.
view_as
(
ideal_output
))
torch
.
testing
.
assert_close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
))
elif
backend
==
'FLASH_ATTN'
:
# For FlashAttention override the accuracy thresholds to non default
# values since we notice a higher difference between the ideal and
# actual output.
torch
.
testing
.
assert_close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
),
atol
=
0.01
,
rtol
=
0.016
)
else
:
raise
ValueError
(
f
"Unknown backend value: '
{
backend
}
'. Expected 'XFORMERS' or "
f
"'FLASH_ATTN'."
)
# Copied/modified from torch._refs.__init__.py
# Copied/modified from torch._refs.__init__.py
...
...
tests/models/encoder_decoder/vision_language/test_florence2.py
View file @
a78dd330
...
@@ -85,7 +85,7 @@ def run_test(
...
@@ -85,7 +85,7 @@ def run_test(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
model
,
dtype
,
max_tokens
,
def
test_models
(
hf_runner
,
vllm_runner
,
model
,
dtype
,
max_tokens
,
...
...
vllm/attention/backends/flash_attn.py
View file @
a78dd330
...
@@ -10,10 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -10,10 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
from
vllm.attention.backends.utils
import
(
compute_slot_mapping
,
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
is_block_tables_empty
)
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
...
@@ -73,7 +74,6 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -73,7 +74,6 @@ class FlashAttentionBackend(AttentionBackend):
src_key_cache
=
src_kv_cache
[
0
]
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
...
@@ -85,6 +85,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -85,6 +85,7 @@ class FlashAttentionBackend(AttentionBackend):
)
->
None
:
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
...
@@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -111,26 +112,12 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
# requests only.
max_prefill_seq_len
:
int
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
# requests only.
max_decode_seq_len
:
int
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
...
@@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -146,11 +133,62 @@ class FlashAttentionMetadata(AttentionMetadata):
# Whether or not cuda graph is enabled.
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
_cached_prefill_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
_cached_prefill_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
@property
def is_all_encoder_attn_metadata_set(self):
    '''
    Whether all attention metadata required for encoder attention is set.

    The check itself is performed by the module-level helper of the same
    name (imported from vllm.attention.backends.utils); inside this method
    body the bare name resolves to that helper, not to this property.
    '''
    encoder_metadata_ready = is_all_encoder_attn_metadata_set(self)
    return encoder_metadata_ready
@property
def is_all_cross_attn_metadata_set(self):
    '''
    Whether all attention metadata required for enc/dec cross-attention
    is set.

    This requirement is a superset of the metadata required for encoder
    attention. The check itself is performed by the module-level helper of
    the same name (imported from vllm.attention.backends.utils).
    '''
    cross_metadata_ready = is_all_cross_attn_metadata_set(self)
    return cross_metadata_ready
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"FlashAttentionMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"FlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
if
self
.
num_prefills
==
0
:
...
@@ -159,32 +197,52 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -159,32 +197,52 @@ class FlashAttentionMetadata(AttentionMetadata):
if
self
.
_cached_prefill_metadata
is
not
None
:
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
((
self
.
seq_lens
is
not
None
)
assert
self
.
seq_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
self
.
query_start_loc
is
not
None
assert
((
self
.
seq_lens_tensor
is
not
None
)
assert
self
.
context_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
assert
self
.
block_tables
is
not
None
assert
self
.
seq_start_loc
is
not
None
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
self
.
_cached_prefill_metadata
=
FlashAttentionMetadata
(
self
.
_cached_prefill_metadata
=
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
]
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
]
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
]
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_query_len
=
0
,
max_decode_query_len
=
0
,
max_decode_seq_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
]
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
]
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
]
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
]
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
return
self
.
_cached_prefill_metadata
@
property
@
property
...
@@ -194,17 +252,25 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -194,17 +252,25 @@ class FlashAttentionMetadata(AttentionMetadata):
if
self
.
_cached_decode_metadata
is
not
None
:
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
((
self
.
seq_lens_tensor
is
not
None
)
assert
self
.
seq_lens_tensor
is
not
None
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
self
.
_cached_decode_metadata
=
FlashAttentionMetadata
(
self
.
_cached_decode_metadata
=
FlashAttentionMetadata
(
num_prefills
=
0
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:]
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:]
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
...
@@ -214,9 +280,15 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -214,9 +280,15 @@ class FlashAttentionMetadata(AttentionMetadata):
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:]
,
block_tables
=
block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
)
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
return
self
.
_cached_decode_metadata
def
advance_step
(
self
,
def
advance_step
(
self
,
...
@@ -586,16 +658,20 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -586,16 +658,20 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
"key/v_scale is not supported in FlashAttention."
)
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
output
=
torch
.
ops
.
vllm
.
unified_flash_attention
(
output
=
torch
.
ops
.
vllm
.
unified_flash_attention
(
query
,
query
,
key
,
key
,
...
@@ -608,6 +684,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -608,6 +684,7 @@ class FlashAttentionImpl(AttentionImpl):
k_scale
,
k_scale
,
v_scale
,
v_scale
,
self
.
scale
,
self
.
scale
,
attn_type
.
value
,
self
.
sliding_window
,
self
.
sliding_window
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
logits_soft_cap
,
self
.
logits_soft_cap
,
...
@@ -616,6 +693,89 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -616,6 +693,89 @@ class FlashAttentionImpl(AttentionImpl):
return
output
return
output
def _get_query_key_seq_metadata(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    """
    Select the query-side and key-side sequence metadata for an attention
    call, depending on the attention type and on whether this is a prompt
    (prefill) run.

    Args:
        attn_metadata: The attention metadata object the fields are read
            from.
        is_prompt (bool): True for a prompt (prefill) run, False for decode.
        attn_type (AttentionType): The type of attention being computed.

    Returns:
        tuple: A 4-tuple, in this order:
            - start locations for the query sequences
              (`seq_start_loc` or `encoder_seq_start_loc` of the metadata),
            - maximum query sequence length,
            - start locations for the key sequences,
            - maximum key sequence length.

    Raises:
        AttributeError: If `attn_type` is not one of the handled
            attention types.
    """
    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: query and key are the same sequence, so
        # both sides share one start-location field and one max length,
        # picked from the prefill or decode field.
        self_attn_max_len = (attn_metadata.max_prefill_seq_len
                             if is_prompt else
                             attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_start_loc, self_attn_max_len,
                attn_metadata.seq_start_loc, self_attn_max_len)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention: the query is the input (decoder) sequence and
        # the key is the precomputed encoder sequence. Only the query-side
        # max length depends on whether this is a prompt run.
        query_max_len = (attn_metadata.max_prefill_seq_len
                         if is_prompt else
                         attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_start_loc, query_max_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention: query and key are both the encoder sequence.
        return (attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER_ONLY:
        # Encoder-only attention is only valid for a prompt run and reads
        # the prefill fields for both query and key.
        assert is_prompt, "Should not have decode for encoder only model."
        return (attn_metadata.seq_start_loc,
                attn_metadata.max_prefill_seq_len,
                attn_metadata.seq_start_loc,
                attn_metadata.max_prefill_seq_len)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: AttentionType) -> bool:
    """
    Decide whether causal masking applies to the given attention type.

    Args:
        attn_type (AttentionType): The type of attention being evaluated.

    Returns:
        bool: `False` when `attn_type` is ENCODER, ENCODER_ONLY or
        ENCODER_DECODER; `True` for any other attention type.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types
def
unified_flash_attention
(
def
unified_flash_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -628,60 +788,76 @@ def unified_flash_attention(
...
@@ -628,60 +788,76 @@ def unified_flash_attention(
k_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
softmax_scale
:
float
,
attn_type_int_val
:
int
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Convert integer attn_type to enum
try
:
attn_type
=
AttentionType
(
attn_type_int_val
)
except
ValueError
as
err
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type_int_val
)
}
"
)
from
err
current_metadata
=
get_forward_context
()
current_metadata
=
get_forward_context
()
assert
current_metadata
is
not
None
assert
current_metadata
is
not
None
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
assert
isinstance
(
current_metadata
,
FlashAttentionMetadata
)
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
attn_metadata
:
FlashAttentionMetadata
=
current_metadata
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
if
kv_cache
.
numel
()
>
0
:
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
value_cache
=
kv_cache
[
1
]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the KV
# cache is already populated with the cross-attention tensor.
# Thus, we skip cache updates during this time.
if
(
attn_type
!=
AttentionType
.
ENCODER
)
and
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
# type: ignore[union-attr]
kv_cache_dtype
,
k_scale
,
v_scale
,
)
# Reshape the input keys and values and store them in the cache.
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
# If kv_cache is not provided, the new key and value tensors are
num_decode_query_tokens
)
=
\
# not cached. This happens during the initial memory profiling run.
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
decode_query
=
query
[
num_prefill_query_tokens
:]
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"key :
{
key
.
shape
}
: #prefill tokens
{
num_prefill_tokens
}
: #decode tokens
{
num_decode_tokens
}
"
# noqa
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
\
f
"value :
{
value
.
shape
}
: #prefill toks
{
num_prefill_tokens
}
: #decode toks
{
num_decode_tokens
}
"
# noqa
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query
=
query
[:
num_prefill_query_tokens
]
key
=
key
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
value
=
value
[:
num_prefill_tokens
]
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
...
@@ -689,22 +865,30 @@ def unified_flash_attention(
...
@@ -689,22 +865,30 @@ def unified_flash_attention(
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
q_seq_start_loc
,
q_seq_len
,
k_seq_start_loc
,
k_seq_len
=
\
_get_query_key_seq_metadata
(
prefill_meta
,
True
,
attn_type
)
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
prefill_output
=
flash_attn_varlen_func
(
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_q
=
q_
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
k_
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill
_seq_len
,
max_seqlen_q
=
q
_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill
_seq_len
,
max_seqlen_k
=
k
_seq_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
_get_causal_option
(
attn_type
)
,
window_size
=
window_size
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
)
)
else
:
else
:
# prefix-enabled attention
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
prefill_output
=
flash_attn_varlen_func
(
# noqa
prefill_output
=
flash_attn_varlen_func
(
# noqa
...
@@ -729,6 +913,8 @@ def unified_flash_attention(
...
@@ -729,6 +913,8 @@ def unified_flash_attention(
# because different queries might have different lengths.
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
assert
decode_meta
.
max_decode_query_len
is
not
None
if
decode_meta
.
max_decode_query_len
>
1
:
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
decode_output
=
flash_attn_varlen_func
(
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
q
=
decode_query
,
k
=
key_cache
,
k
=
key_cache
,
...
@@ -746,12 +932,17 @@ def unified_flash_attention(
...
@@ -746,12 +932,17 @@ def unified_flash_attention(
)
)
else
:
else
:
# Use flash_attn_with_kvcache for normal decoding.
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg
,
_
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
decode_output
=
flash_attn_with_kvcache
(
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
v_cache
=
value_cache
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
block_tables
_arg
,
cache_seqlens
=
decode_meta
.
seq_lens_
tensor
,
cache_seqlens
=
seq_lens_
arg
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
window_size
=
window_size
,
...
@@ -761,10 +952,10 @@ def unified_flash_attention(
...
@@ -761,10 +952,10 @@ def unified_flash_attention(
if
prefill_output
is
None
:
if
prefill_output
is
None
:
assert
decode_output
is
not
None
assert
decode_output
is
not
None
return
decode_output
.
view
(
num_decode_tokens
,
hidden_size
)
return
decode_output
.
view
(
num_decode_
query_
tokens
,
hidden_size
)
if
decode_output
is
None
:
if
decode_output
is
None
:
assert
prefill_output
is
not
None
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_tokens
,
hidden_size
)
return
prefill_output
.
view
(
num_prefill_
query_
tokens
,
hidden_size
)
# Chunked prefill does not work with speculative decoding.
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
# Therefore, the query length for decode should be 1 in chunked prefill.
...
@@ -786,6 +977,7 @@ def unified_flash_attention_fake(
...
@@ -786,6 +977,7 @@ def unified_flash_attention_fake(
k_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
v_scale
:
float
,
softmax_scale
:
float
,
softmax_scale
:
float
,
attn_type_int_val
:
int
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
...
...
vllm/attention/backends/utils.py
View file @
a78dd330
"""Attention backend utils"""
"""Attention backend utils"""
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
...
@@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
...
@@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
use_cuda_graph
=
True
,
use_cuda_graph
=
True
,
)
)
if
is_encoder_decoder_model
:
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# The encoder decoder model works only with XFormers and
# Assert the same.
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_update_captured_metadata_for_enc_dec_model
(
self
.
_update_captured_metadata_for_enc_dec_model
(
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
...
@@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
...
@@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
}
if
is_encoder_decoder_model
:
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# The encoder decoder model works only with XFormers and
# Assert the same.
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
return
input_buffers
...
@@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
...
@@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
input_buffers
[
"block_tables"
].
copy_
(
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
if
is_encoder_decoder_model
:
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# The encoder decoder model works only with XFormers and
# Assert the same.
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_prepare_input_buffers_for_enc_dec_model
(
self
.
_prepare_input_buffers_for_enc_dec_model
(
attn_metadata
,
input_buffers
)
attn_metadata
,
input_buffers
)
...
@@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
...
@@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
attn_metadata
.
encoder_seq_lens_tensor
=
torch
.
full
(
attn_metadata
.
encoder_seq_lens_tensor
=
torch
.
full
(
(
batch_size
,
),
1
,
dtype
=
torch
.
int
).
cuda
()
(
batch_size
,
),
1
,
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
num_encoder_tokens
=
0
def
_add_additonal_input_buffers_for_enc_dec_model
(
def
_add_additonal_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
...
@@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
...
@@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
input_buffers
[
"cross_block_tables"
].
copy_
(
input_buffers
[
"cross_block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
cross_block_tables
,
attn_metadata
.
decode_metadata
.
cross_block_tables
,
non_blocking
=
True
)
non_blocking
=
True
)
def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    Report whether every field needed for encoder attention is populated.

    Returns True only when `encoder_seq_lens`, `encoder_seq_lens_tensor`
    and `max_encoder_seq_len` on `attn_metadata` are all not None.
    '''
    # Fields are inspected in order; the first missing one ends the check,
    # so later attributes are not read once an earlier one is None.
    if attn_metadata.encoder_seq_lens is None:
        return False
    if attn_metadata.encoder_seq_lens_tensor is None:
        return False
    return attn_metadata.max_encoder_seq_len is not None
def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    Report whether every field needed for enc/dec cross-attention is set.

    This is a superset of the encoder-attention requirements: the
    `is_all_encoder_attn_metadata_set` attribute of `attn_metadata` must be
    truthy, and both `cross_slot_mapping` and `cross_block_tables` must be
    not None.
    '''
    encoder_ready = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Hand back the falsy value itself, exactly as an `and` chain would.
        return encoder_ready
    if attn_metadata.cross_slot_mapping is None:
        return False
    return attn_metadata.cross_block_tables is not None
def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables
    Encoder-only attn -> select decoder-style prefill sequence lengths
                         & no block tables (prefill only)

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention, encoder-only attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AssertionError: if attn_type is ENCODER_ONLY and is_prompt is False
    * AttributeError: if attn_type is not a recognized attention type
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    elif attn_type == AttentionType.ENCODER_ONLY:
        # Restored from the backend-specific helper this function replaces:
        # encoder-only attention reads the regular (non-encoder) sequence
        # length fields and never uses block tables.
        assert is_prompt, "Should not have decode for encoder only model."
        return (attn_metadata.seq_lens_tensor,
                attn_metadata.max_prefill_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: AttentionType,
) -> Tuple[int, int, int]:
    """
    Derive the prefill/decode token counts for the query and key/value
    tensors from the attention metadata and the attention type.

    Args:
        attn_metadata (FlashAttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.

    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase, and
        # the same input serves as both query and key, so there are no
        # decode query tokens.
        assert attn_metadata.num_encoder_tokens is not None
        return (attn_metadata.num_encoder_tokens,
                attn_metadata.num_encoder_tokens, 0)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        # Queries come from the decoder side; the key/value side is the
        # encoder (cross-attention) sequence.
        return (attn_metadata.num_prefill_tokens,
                attn_metadata.num_encoder_tokens,
                attn_metadata.num_decode_tokens)

    # attn_type == AttentionType.DECODER or
    # attn_type == AttentionType.ENCODER_ONLY
    return (attn_metadata.num_prefill_tokens,
            attn_metadata.num_prefill_tokens,
            attn_metadata.num_decode_tokens)
vllm/attention/backends/xformers.py
View file @
a78dd330
...
@@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
...
@@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
from
vllm.attention.backends.utils
import
(
CommonMetadataBuilder
)
CommonAttentionState
,
CommonMetadataBuilder
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Encoder sequence lengths representation
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
max_encoder_seq_len
:
Optional
[
int
]
=
None
...
@@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
'''
'''
All attention metadata required for encoder attention is set.
All attention metadata required for encoder attention is set.
'''
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
return
is_all_encoder_attn_metadata_set
(
self
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
@
property
@
property
def
is_all_cross_attn_metadata_set
(
self
):
def
is_all_cross_attn_metadata_set
(
self
):
...
@@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
Superset of encoder attention required metadata.
Superset of encoder attention required metadata.
'''
'''
return
(
self
.
is_all_encoder_attn_metadata_set
return
is_all_cross_attn_metadata_set
(
self
)
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
@
property
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
...
@@ -329,64 +332,6 @@ def _set_attn_bias(
...
@@ -329,64 +332,6 @@ def _set_attn_bias(
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def _get_seq_len_block_table_args(
    attn_metadata: XFormersMetadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    '''
    Pick the sequence-length and block-table fields of attn_metadata that
    match the kind of attention being computed.

    Decoder attn -> decoder self-attention fields only
    Encoder/decoder cross-attn -> encoder sequence lengths &
                                  cross-attn block tables
    Encoder attn -> encoder sequence lengths & no block tables
    Encoder-only attn -> prefill sequence lengths & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention, encoder-only attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: the max length depends on whether this
        # is a prompt (prefill) run or a decode run.
        longest = (attn_metadata.max_prefill_seq_len
                   if is_prompt else attn_metadata.max_decode_seq_len)
        return (attn_metadata.seq_lens_tensor, longest,
                attn_metadata.block_tables)

    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross-attention KVs follow the encoder sequence length and use
        # the dedicated "cross" block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)

    if attn_type == AttentionType.ENCODER:
        # Encoder attention has no block tables.
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)

    if attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."

        # Encoder-only attention has no block tables either.
        return (attn_metadata.seq_lens_tensor,
                attn_metadata.max_prefill_seq_len, None)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
_metadata_cls
=
XFormersMetadata
_metadata_cls
=
XFormersMetadata
...
@@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
updated_slot_mapping
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
k_scale
,
v_scale
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
if
attn_type
==
AttentionType
.
ENCODER
:
num_decode_query_tokens
)
=
\
# Encoder attention - chunked prefill is not applicable;
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
elif
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
else
:
# attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
if
attn_metadata
.
num_encoder_tokens
is
not
None
:
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
else
:
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query
=
query
[
num_prefill_
query_
tokens
:]
# QKV for prefill.
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query
=
query
[:
num_prefill_
query_
tokens
]
if
key
is
not
None
and
value
is
not
None
:
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_
encoder
_tokens
]
key
=
key
[:
num_
prefill_kv
_tokens
]
value
=
value
[:
num_
encoder
_tokens
]
value
=
value
[:
num_
prefill_kv
_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
query
.
shape
[
0
]
==
num_prefill_
query_
tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_
query_
tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# Prompt run.
...
@@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# prefix.
# prefix.
out
=
self
.
_run_memory_efficient_xformers_forward
(
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
assert
out
.
shape
==
output
[:
num_prefill_
query_
tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_
query_
tokens
]
=
out
else
:
else
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have prefix attention."
)
"Encoder-only models should not have prefix attention."
)
...
@@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
k_scale
,
k_scale
,
v_scale
,
v_scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_
query_
tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_
query_
tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
...
@@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
seq_lens_arg
,
seq_lens_arg
,
max_seq_len_arg
,
max_seq_len_arg
,
block_tables_arg
,
block_tables_arg
,
)
=
_
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_
query_
tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
...
vllm/attention/selector.py
View file @
a78dd330
...
@@ -98,7 +98,6 @@ def get_attn_backend(
...
@@ -98,7 +98,6 @@ def get_attn_backend(
is_blocksparse
:
bool
=
False
,
is_blocksparse
:
bool
=
False
,
)
->
Type
[
AttentionBackend
]:
)
->
Type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
"""Selects which attention backend to use and lazily imports it."""
if
is_blocksparse
:
if
is_blocksparse
:
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
from
vllm.attention.backends.blocksparse_attn
import
(
from
vllm.attention.backends.blocksparse_attn
import
(
...
@@ -108,6 +107,7 @@ def get_attn_backend(
...
@@ -108,6 +107,7 @@ def get_attn_backend(
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
is_attention_free
)
is_attention_free
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using Flash Attention backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
)
return
FlashAttentionBackend
return
FlashAttentionBackend
...
...
vllm/model_executor/models/bart.py
View file @
a78dd330
...
@@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
...
@@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor
Decoder output torch.Tensor
"""
"""
# retrieve input_ids and inputs_embeds
# retrieve input_ids and inputs_embeds
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
shape
[
-
1
])
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
embed_pos
=
self
.
embed_positions
(
embed_pos
=
self
.
embed_positions
(
...
...
vllm/utils.py
View file @
a78dd330
...
@@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
...
@@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"currently supported with encoder/"
"decoder models."
)
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_BACKEND
=
(
"XFormers
is
the only
backend
"
STR_NOT_IMPL_ENC_DEC_BACKEND
=
(
"XFormers
and Flash-Attention are
the only "
"currently supported with encoder/"
"
backends
currently supported with encoder/"
"decoder models."
)
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
=
(
"Prompt adapters are not "
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
=
(
"Prompt adapters are not "
...
...
vllm/worker/enc_dec_model_runner.py
View file @
a78dd330
...
@@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
...
@@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.utils
import
get_architecture_class_name
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalInputs
,
MultiModalRegistry
)
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
...
@@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# The Mllama model contains PagedAttention-specific logic, so it can
# only be run with the XFORMERS backend.
# TODO: Make the Mllama model work with the Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS
=
[
"MllamaForConditionalGeneration"
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
EncoderDecoderModelInput
(
ModelInputForGPUWithSamplingMetadata
):
class
EncoderDecoderModelInput
(
ModelInputForGPUWithSamplingMetadata
):
...
@@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
models) but these arguments are present here for compatibility with
the base-class constructor.
the base-class constructor.
'''
'''
self
.
_maybe_force_supported_attention_backend
(
model_config
)
self
.
_maybe_force_supported_attention_backend
()
super
().
__init__
(
super
().
__init__
(
model_config
,
model_config
,
parallel_config
,
parallel_config
,
...
@@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario
(
self
)
assert_enc_dec_mr_supported_scenario
(
self
)
def
_maybe_force_supported_attention_backend
(
self
):
def _is_xformers_only_encoder_decoder_model(self,
                                            model: ModelConfig) -> bool:
    '''
    Return True when the architecture class name resolved for `model` is
    listed in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS.
    '''
    arch_name = get_architecture_class_name(model)
    return arch_name in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def
_maybe_force_supported_attention_backend
(
self
,
model
:
ModelConfig
):
'''
'''
Force vLLM to use the XFormers attention backend,
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
which is currently the only supported option.
...
@@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_forced_by_global
=
maybe_global_forced_backend
is
not
None
is_forced_by_global
=
maybe_global_forced_backend
is
not
None
is_forced_by_env_var
=
maybe_env_var_forced_backend
is
not
None
is_forced_by_env_var
=
maybe_env_var_forced_backend
is
not
None
if
not
(
is_forced_by_global
or
is_forced_by_env_var
):
if
not
(
is_forced_by_global
or
is_forced_by_env_var
)
\
and
self
.
_is_xformers_only_encoder_decoder_model
(
model
):
# The user has not already specified an attention backend
# The user has not already specified an attention backend
# override
# override
logger
.
info
(
"EncoderDecoderModelRunner requires "
logger
.
info
(
"XFormers backend; overriding backend "
"Encoder-Decoder Model Architecture %s requires XFormers "
"auto-selection and forcing XFormers."
)
"backend; overriding backend auto-selection and "
"forcing XFormers."
,
get_architecture_class_name
(
model
))
global_force_attn_backend
(
_Backend
.
XFORMERS
)
global_force_attn_backend
(
_Backend
.
XFORMERS
)
elif
is_forced_by_global
:
elif
is_forced_by_global
:
# Backend override enforced by global variable takes
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
# precedence over vLLM backend environment variable.
if
maybe_global_forced_backend
!=
_Backend
.
XFORMERS
:
if
maybe_global_forced_backend
not
in
\
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]:
raise_backend_err
()
raise_backend_err
()
elif
is_forced_by_env_var
:
elif
is_forced_by_env_var
:
# Backend override enforced by vLLM backend
# Backend override enforced by vLLM backend
# environment variable
# environment variable
if
maybe_env_var_forced_backend
!=
_Backend
.
XFORMERS
:
if
maybe_env_var_forced_backend
not
in
\
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]:
raise_backend_err
()
raise_backend_err
()
def
_list_to_int32_tensor
(
def
_list_to_int32_tensor
(
...
@@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
attn_metadata
.
encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
encoder_seq_start_loc
,
attn_metadata
.
cross_slot_mapping
,
attn_metadata
.
cross_slot_mapping
,
attn_metadata
.
cross_block_tables
,
attn_metadata
.
cross_block_tables
,
)
=
(
)
=
(
...
@@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_seq_lens
,
encoder_seq_lens
,
encoder_seq_lens_tensor
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
,
max_encoder_seq_len
,
encoder_seq_start_loc
,
cross_slot_mapping_tensor
,
cross_slot_mapping_tensor
,
cross_block_tables
,
cross_block_tables
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment