Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a78dd330
Unverified
Commit
a78dd330
authored
Nov 01, 2024
by
sroy745
Committed by
GitHub
Nov 01, 2024
Browse files
[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)
parent
d522034c
Changes
11
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
716 additions
and
317 deletions
+716
-317
tests/encoder_decoder/test_e2e_correctness.py
tests/encoder_decoder/test_e2e_correctness.py
+51
-37
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+115
-41
tests/kernels/utils.py
tests/kernels/utils.py
+74
-16
tests/models/encoder_decoder/vision_language/test_florence2.py
.../models/encoder_decoder/vision_language/test_florence2.py
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+278
-86
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+143
-16
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+26
-105
vllm/attention/selector.py
vllm/attention/selector.py
+1
-1
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+0
-2
vllm/utils.py
vllm/utils.py
+2
-2
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-10
No files found.
tests/encoder_decoder/test_e2e_correctness.py
View file @
a78dd330
...
...
@@ -7,12 +7,18 @@ from typing import List, Optional, Tuple
import
pytest
from
transformers
import
AutoModelForSeq2SeqLM
from
vllm.attention.selector
import
(
_Backend
,
global_force_attn_backend_context_manager
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
DecoderPromptType
from
..models.utils
import
check_logprobs_close
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
None
]
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
...
...
@@ -29,7 +35,8 @@ def vllm_to_hf_output(
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/bart-large-cnn"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
LIST_ENC_DEC_SUPPORTED_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"decoder_prompt_type"
,
list
(
DecoderPromptType
))
...
...
@@ -48,6 +55,7 @@ def test_encoder_decoder_e2e(
num_logprobs
:
int
,
decoder_prompt_type
:
DecoderPromptType
,
enforce_eager
:
bool
,
attn_backend
:
_Backend
,
)
->
None
:
'''
End-to-End (E2E) test for the encoder-decoder framework.
...
...
@@ -56,7 +64,12 @@ def test_encoder_decoder_e2e(
implementations to ensure that both implementations produce consistent
and correct results.
'''
test_case_prompts
=
example_encoder_decoder_prompts
[
decoder_prompt_type
]
with
global_force_attn_backend_context_manager
(
attn_backend
):
if
attn_backend
==
_Backend
.
FLASH_ATTN
:
# Flash Attention works only with bfloat16 data-type
dtype
=
'bfloat16'
test_case_prompts
=
example_encoder_decoder_prompts
[
decoder_prompt_type
]
# Configuration settings for HF baseline
hf_kwargs
=
{
...
...
@@ -72,7 +85,8 @@ def test_encoder_decoder_e2e(
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSeq2SeqLM
)
as
hf_model
:
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
test_case_prompts
,
max_tokens
,
num_logprobs
,
...
...
@@ -83,8 +97,8 @@ def test_encoder_decoder_e2e(
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
test_case_prompts
,
max_tokens
,
num_logprobs
)
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
hf_skip_tokens
=
(
1
if
decoder_prompt_type
==
DecoderPromptType
.
NONE
else
0
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
a78dd330
...
...
@@ -16,13 +16,13 @@ from tests.kernels.utils import *
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
from
vllm.attention.selector
import
(
_Backend
,
get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
]
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
HEAD_SIZES
=
[
64
,
256
]
NUM_HEADS
=
[
1
,
16
]
...
...
@@ -145,7 +145,8 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
)
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
kv_cache
)
...
...
@@ -592,6 +593,7 @@ def _run_encoder_attention_test(
attn
:
Attention
,
encoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
'''
Run encoder attention.
...
...
@@ -610,6 +612,8 @@ def _run_encoder_attention_test(
(number_of_tokens x num_heads x head_size)
query/key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed {query,key,value} and
...
...
@@ -619,7 +623,17 @@ def _run_encoder_attention_test(
attn_type
=
AttentionType
.
ENCODER
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
return
attn
.
forward
(
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
torch
.
tensor
([],
...
...
@@ -633,6 +647,7 @@ def _run_decoder_self_attention_test(
test_rsrcs
:
TestResources
,
decoder_test_params
:
PhaseTestParameters
,
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
'''
Run decoder self-attention test.
...
...
@@ -650,6 +665,8 @@ def _run_decoder_self_attention_test(
query/key/value fields
* attn_metadata: attention metadata for decoder-self attention
(contains KV cache memory-mapping)
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
...
...
@@ -660,7 +677,17 @@ def _run_decoder_self_attention_test(
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
return
attn
.
forward
(
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
kv_cache
,
...
...
@@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test(
decoder_test_params
:
PhaseTestParameters
,
cross_test_params
:
Optional
[
PhaseTestParameters
],
attn_metadata
:
AttentionMetadata
,
test_pt
:
TestPoint
,
)
->
torch
.
Tensor
:
'''
Run encoder/decoder cross-attention test.
...
...
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test(
(number_of_tokens x num_heads x head_size)
key/value fields
* attn_metadata: attention metadata for encoder/decoder-self attention
* test_pt: The TestPoint object containing test details like number of
model heads, head size, name of the backend being used etc.
Returns:
* Attention.forward() applied to packed_{query,key,value}, kv_cache
...
...
@@ -718,7 +748,17 @@ def _run_encoder_decoder_cross_attention_test(
cross_pckd_qkv
=
cross_test_params
.
packed_qkvo
.
packed_qkv
key
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
key
)
value
=
(
None
if
cross_pckd_qkv
is
None
else
cross_pckd_qkv
.
value
)
return
attn
.
forward
(
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
,
with
set_forward_context
(
attn_metadata
):
# In the test setup the shape of the query is
# [batch_size, seq_len, num_heads, head_size]. However
# the attention backend expect the shape to be
# [num_tokens, hidden_size]. Hence reshape the query before
# invoking the forward method.
# TODO - Update the way we construct the query so that it
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
...
...
@@ -726,6 +766,21 @@ def _run_encoder_decoder_cross_attention_test(
attn_type
=
attn_type
)
@
pytest
.
fixture
(
autouse
=
True
)
def
set_reset_environment
(
attn_backend
):
# Set the default torch datatype to bfloat16 to enable
# testing of the Flash Attention backend. Also clear the
# cached value of the backend.
default_dtype
=
torch
.
get_default_dtype
()
if
attn_backend
.
name
==
'FLASH_ATTN'
:
torch
.
set_default_dtype
(
torch
.
bfloat16
)
get_attn_backend
.
cache_clear
()
yield
# Reset the torch datatype to what it was before the test
# so as not to impact the remaining tests.
torch
.
set_default_dtype
(
default_dtype
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -773,10 +828,8 @@ def test_encoder_only(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
with
global_force_attn_backend_context_manager
(
attn_backend
):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
...
...
@@ -807,10 +860,14 @@ def test_encoder_only(
# PREFILL: encoder attention
enc_pckd_act_out
:
torch
.
Tensor
=
(
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
))
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
))
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
)
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
attn_backend
.
name
)
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
...
...
@@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn(
* max_dec_seq_len: max length of decoder input sequences
* max_enc_seq_len: max length of encoder input sequences
'''
# Force Attention wrapper backend
with
global_force_attn_backend_context_manager
(
attn_backend
):
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
...
...
@@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn(
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_test_params
,
prephase_attn_metadata
)
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is encoder attention result correct?
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
)
assert_actual_matches_ideal
(
enc_test_params
,
enc_pckd_act_out
,
attn_backend
.
name
)
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
)
test_rsrcs
,
prephase_dec_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is prefill decoder self-attention correct?
assert_actual_matches_ideal
(
prephase_dec_test_params
,
prephase_dec_pckd_act_out
)
prephase_dec_pckd_act_out
,
attn_backend
.
name
)
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
)
test_rsrcs
,
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
,
test_pt
=
test_pt
)
# - Is prefill encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
prephase_cross_test_params
,
prephase_cross_pckd_act_out
)
prephase_cross_pckd_act_out
,
attn_backend
.
name
)
# DECODE: build decode-phase attention metadata
...
...
@@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
)
test_rsrcs
,
decphase_dec_test_params
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
# - Is decode-phase decoder self-attention correct?
assert_actual_matches_ideal
(
decphase_dec_test_params
,
decphase_dec_pckd_act_out
)
decphase_dec_pckd_act_out
,
attn_backend
.
name
)
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
decphase_dec_test_params
,
None
,
decphase_attn_metadata
)
test_rsrcs
,
decphase_dec_test_params
,
None
,
decphase_attn_metadata
,
test_pt
=
test_pt
)
# - Is decode-phase encoder/decoder cross-attention correct?
assert_actual_matches_ideal
(
decphase_cross_test_params
,
decphase_cross_pckd_act_out
)
decphase_cross_pckd_act_out
,
attn_backend
.
name
)
tests/kernels/utils.py
View file @
a78dd330
...
...
@@ -13,8 +13,8 @@ from torch._prims_common import TensorLikeType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_
XFORMERS
_ATTN_VAL
,
make_tensor_with_pad
)
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_
FLASH
_ATTN_VAL
,
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
...
...
@@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend:
if
backend_name
==
STR_XFORMERS_ATTN_VAL
:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from
vllm.attention.backends.xformers
import
XFormersBackend
return
XFormersBackend
()
elif
backend_name
==
STR_FLASH_ATTN_VAL
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
return
FlashAttentionBackend
()
raise
AssertionError
(
f
"Unrecognized backend_name
{
backend_name
}
for unit test"
)
def
_make_metadata_tensors
(
seq_lens
:
Optional
[
List
[
int
]],
context_lens
:
Optional
[
List
[
int
]],
encoder_seq_lens
:
Optional
[
List
[
int
]],
device
:
Union
[
torch
.
device
,
str
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
List
[
int
]],
torch
.
Tensor
,
Optional
[
int
]]:
seq_lens
:
Optional
[
List
[
int
]],
context_lens
:
Optional
[
List
[
int
]],
encoder_seq_lens
:
Optional
[
List
[
int
]],
device
:
Union
[
torch
.
device
,
str
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
int
]]:
'''
Build scalar & tensor values required to build attention metadata structure.
...
...
@@ -553,6 +558,8 @@ def _make_metadata_tensors(
* max_context_len: max(context_lens)
* max_seq_len: max(seq_lens)
* seq_start_loc: start idx of each sequence
* encoder_seq_lens_tensor: encoder seq_lens list, as tensor
* encoder_seq_start_loc: start idx of each encoder sequence
* max_encoder_seq_len: encoder seq_lens list, as tensor
'''
seq_lens_tensor
=
maybe_make_int_tensor
(
seq_lens
,
device
)
...
...
@@ -566,8 +573,26 @@ def _make_metadata_tensors(
seq_start_loc
=
None
if
seq_lens_tensor
is
not
None
:
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
seq_lens_tensor
.
device
)
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
encoder_seq_start_loc
=
torch
.
zeros
(
encoder_seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
encoder_seq_lens_tensor
.
device
)
torch
.
cumsum
(
encoder_seq_lens_tensor
,
dim
=
0
,
dtype
=
encoder_seq_start_loc
.
dtype
,
out
=
encoder_seq_start_loc
[
1
:])
return
(
seq_lens_tensor
,
context_lens_tensor
,
max_context_len
,
max_seq_len
,
seq_start_loc
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
)
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
)
def
make_kv_cache
(
num_blocks
:
int
,
...
...
@@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int,
head_size
:
int
,
block_size
:
int
,
device
:
Union
[
torch
.
device
,
str
],
backend
:
str
,
default_val
:
float
=
0.0
)
->
torch
.
Tensor
:
'''
Create a fake KV cache.
...
...
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int,
Returns:
* kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
* for backend 'XFORMERS'
* kv_cache: 2 x num_blocks x block_size x num_heads x head_size
* for backend 'FLASH_ATTN'
'''
if
backend
==
'XFORMERS'
:
kv_cache
=
torch
.
rand
(
(
2
,
num_blocks
,
block_size
*
num_heads
*
head_size
)).
to
(
device
)
elif
backend
==
'FLASH_ATTN'
:
kv_cache
=
torch
.
rand
(
(
2
,
num_blocks
,
block_size
,
num_heads
,
head_size
)).
to
(
device
)
else
:
raise
ValueError
(
f
"Unknown backend value: '
{
backend
}
'. Expected 'XFORMERS' or "
f
"'FLASH_ATTN'."
)
if
default_val
is
not
None
:
kv_cache
[:,
:,
:]
=
default_val
return
kv_cache
...
...
@@ -858,8 +894,9 @@ def make_test_metadata(
context_lens_tensor
,
_
,
_
,
_
,
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
,
)
=
_make_metadata_tensors
(
seq_lens
,
context_lens
,
...
...
@@ -874,6 +911,7 @@ def make_test_metadata(
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_start_loc
=
seq_start_loc
,
max_prefill_seq_len
=
None
if
seq_lens
is
None
else
max
(
seq_lens
),
max_decode_seq_len
=
0
,
context_lens_tensor
=
context_lens_tensor
,
...
...
@@ -882,6 +920,7 @@ def make_test_metadata(
num_encoder_tokens
=
num_encoder_tokens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
encoder_seq_start_loc
,
max_encoder_seq_len
=
max_encoder_seq_len
,
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_kv_mmap
.
slot_mapping
),
...
...
@@ -904,8 +943,9 @@ def make_test_metadata(
context_lens_tensor
,
_
,
_
,
_
,
seq_start_loc
,
encoder_seq_lens_tensor
,
encoder_seq_start_loc
,
max_encoder_seq_len
,
)
=
_make_metadata_tensors
(
seq_lens
,
context_lens
,
...
...
@@ -920,14 +960,17 @@ def make_test_metadata(
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_start_loc
=
seq_start_loc
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
max
(
seq_lens
),
max_decode_query_len
=
1
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
kv_mmap
.
block_tables
,
use_cuda_graph
=
False
,
num_encoder_tokens
=
num_encoder_tokens
,
encoder_seq_lens
=
encoder_seq_lens
,
encoder_seq_lens_tensor
=
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
encoder_seq_start_loc
,
max_encoder_seq_len
=
max_encoder_seq_len
,
cross_slot_mapping
=
(
None
if
cross_kv_mmap
is
None
else
cross_kv_mmap
.
slot_mapping
),
...
...
@@ -936,7 +979,8 @@ def make_test_metadata(
def
assert_actual_matches_ideal
(
test_params
:
PhaseTestParameters
,
output_under_test
:
torch
.
Tensor
)
->
None
:
output_under_test
:
torch
.
Tensor
,
backend
:
str
)
->
None
:
'''
Assert that observed output matches the ideal output
contained in the test parameters data structure.
...
...
@@ -947,9 +991,23 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value
'''
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
if
backend
==
'XFORMERS'
:
torch
.
testing
.
assert_close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
))
elif
backend
==
'FLASH_ATTN'
:
# For FlashAttention override the accuracy thresholds to non default
# values since we notice a higher difference between the ideal and
# actual output.
torch
.
testing
.
assert_close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
),
atol
=
0.01
,
rtol
=
0.016
)
else
:
raise
ValueError
(
f
"Unknown backend value: '
{
backend
}
'. Expected 'XFORMERS' or "
f
"'FLASH_ATTN'."
)
# Copied/modified from torch._refs.__init__.py
def
fp8_allclose
(
...
...
tests/models/encoder_decoder/vision_language/test_florence2.py
View file @
a78dd330
...
...
@@ -85,7 +85,7 @@ def run_test(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
model
,
dtype
,
max_tokens
,
...
...
vllm/attention/backends/flash_attn.py
View file @
a78dd330
This diff is collapsed.
Click to expand it.
vllm/attention/backends/utils.py
View file @
a78dd330
"""Attention backend utils"""
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
import
numpy
as
np
import
torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
...
...
@@ -336,11 +337,13 @@ class CommonAttentionState(AttentionState):
use_cuda_graph
=
True
,
)
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_update_captured_metadata_for_enc_dec_model
(
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
...
...
@@ -356,11 +359,13 @@ class CommonAttentionState(AttentionState):
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additonal_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
...
...
@@ -375,11 +380,13 @@ class CommonAttentionState(AttentionState):
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
if
is_encoder_decoder_model
:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
==
"XFORMERS"
,
\
f
"Expected attn_backend name to be 'XFORMERS', but "
\
f
" got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'FLASH_ATTN', but "
\
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_prepare_input_buffers_for_enc_dec_model
(
attn_metadata
,
input_buffers
)
...
...
@@ -411,6 +418,7 @@ class CommonAttentionState(AttentionState):
attn_metadata
.
encoder_seq_lens_tensor
=
torch
.
full
(
(
batch_size
,
),
1
,
dtype
=
torch
.
int
).
cuda
()
attn_metadata
.
max_encoder_seq_len
=
self
.
runner
.
max_seq_len_to_capture
attn_metadata
.
num_encoder_tokens
=
0
def
_add_additonal_input_buffers_for_enc_dec_model
(
self
,
attn_metadata
,
input_buffers
:
Dict
[
str
,
Any
]):
...
...
@@ -453,3 +461,122 @@ class CommonAttentionState(AttentionState):
input_buffers
[
"cross_block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
cross_block_tables
,
non_blocking
=
True
)
def
is_all_encoder_attn_metadata_set
(
attn_metadata
):
'''
All attention metadata required for encoder attention is set.
'''
return
((
attn_metadata
.
encoder_seq_lens
is
not
None
)
and
(
attn_metadata
.
encoder_seq_lens_tensor
is
not
None
)
and
(
attn_metadata
.
max_encoder_seq_len
is
not
None
))
def
is_all_cross_attn_metadata_set
(
attn_metadata
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
(
attn_metadata
.
is_all_encoder_attn_metadata_set
and
(
attn_metadata
.
cross_slot_mapping
is
not
None
)
and
(
attn_metadata
.
cross_block_tables
is
not
None
))
def
get_seq_len_block_table_args
(
attn_metadata
,
is_prompt
:
bool
,
attn_type
:
AttentionType
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if
is_prompt
:
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
else
:
max_seq_len
=
attn_metadata
.
max_decode_seq_len
return
(
attn_metadata
.
seq_lens_tensor
,
max_seq_len
,
attn_metadata
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
:
AttentionType
,
)
->
Tuple
[
int
,
int
,
int
]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (FlashAttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens
=
0
num_decode_query_tokens
=
0
num_prefill_kv_tokens
=
0
if
attn_type
==
AttentionType
.
ENCODER
:
# Encoder attention is only invoked during prefill phase.
# The same input servers a both query and key.
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_query_tokens
=
attn_metadata
.
num_encoder_tokens
num_prefill_kv_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_query_tokens
=
0
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_query_tokens
=
attn_metadata
.
num_prefill_tokens
# The key is the encoder/cross-attention.
num_prefill_kv_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_query_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_kv_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_query_tokens
=
attn_metadata
.
num_decode_tokens
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
vllm/attention/backends/xformers.py
View file @
a78dd330
...
...
@@ -11,8 +11,10 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
...
...
@@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
...
...
@@ -162,9 +169,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
'''
All attention metadata required for encoder attention is set.
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
return
is_all_encoder_attn_metadata_set
(
self
)
@
property
def
is_all_cross_attn_metadata_set
(
self
):
...
...
@@ -173,9 +178,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
Superset of encoder attention required metadata.
'''
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
return
is_all_cross_attn_metadata_set
(
self
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
...
...
@@ -329,64 +332,6 @@ def _set_attn_bias(
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
_get_seq_len_block_table_args
(
attn_metadata
:
XFormersMetadata
,
is_prompt
:
bool
,
attn_type
:
AttentionType
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if
is_prompt
:
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
else
:
max_seq_len
=
attn_metadata
.
max_decode_seq_len
return
(
attn_metadata
.
seq_lens_tensor
,
max_seq_len
,
attn_metadata
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
None
)
elif
attn_type
==
AttentionType
.
ENCODER_ONLY
:
assert
is_prompt
,
"Should not have decode for encoder only model."
# No block tables associated with encoder attention
return
(
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
max_prefill_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
_metadata_cls
=
XFormersMetadata
...
...
@@ -574,45 +519,21 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_type
==
AttentionType
.
ENCODER
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
elif
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
else
:
# attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
if
attn_metadata
.
num_encoder_tokens
is
not
None
:
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
else
:
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query
=
query
[
num_prefill_
query_
tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query
=
query
[:
num_prefill_
query_
tokens
]
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_
encoder
_tokens
]
value
=
value
[:
num_
encoder
_tokens
]
key
=
key
[:
num_
prefill_kv
_tokens
]
value
=
value
[:
num_
prefill_kv
_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
query
.
shape
[
0
]
==
num_prefill_
query_
tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_
query_
tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
...
...
@@ -622,8 +543,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# prefix.
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
assert
out
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
out
assert
out
.
shape
==
output
[:
num_prefill_
query_
tokens
].
shape
output
[:
num_prefill_
query_
tokens
]
=
out
else
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have prefix attention."
)
...
...
@@ -652,8 +573,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
k_scale
,
v_scale
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
assert
output
[:
num_prefill_
query_
tokens
].
shape
==
out
.
shape
output
[:
num_prefill_
query_
tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
...
...
@@ -663,9 +584,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
_
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
output
[
num_prefill_
query_
tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
...
...
vllm/attention/selector.py
View file @
a78dd330
...
...
@@ -98,7 +98,6 @@ def get_attn_backend(
is_blocksparse
:
bool
=
False
,
)
->
Type
[
AttentionBackend
]:
"""Selects which attention backend to use and lazily imports it."""
if
is_blocksparse
:
logger
.
info
(
"Using BlocksparseFlashAttention backend."
)
from
vllm.attention.backends.blocksparse_attn
import
(
...
...
@@ -108,6 +107,7 @@ def get_attn_backend(
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
is_attention_free
)
if
backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
"Using Flash Attention backend."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
return
FlashAttentionBackend
...
...
vllm/model_executor/models/bart.py
View file @
a78dd330
...
...
@@ -624,8 +624,6 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
shape
[
-
1
])
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
embed_pos
=
self
.
embed_positions
(
...
...
vllm/utils.py
View file @
a78dd330
...
...
@@ -80,8 +80,8 @@ STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
"currently supported with encoder/"
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_BACKEND
=
(
"XFormers
is
the only
backend
"
"currently supported with encoder/"
STR_NOT_IMPL_ENC_DEC_BACKEND
=
(
"XFormers
and Flash-Attention are
the only "
"
backends
currently supported with encoder/"
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
=
(
"Prompt adapters are not "
...
...
vllm/worker/enc_dec_model_runner.py
View file @
a78dd330
...
...
@@ -19,6 +19,7 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.utils
import
get_architecture_class_name
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalInputs
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -36,6 +37,11 @@ from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS
=
[
"MllamaForConditionalGeneration"
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
EncoderDecoderModelInput
(
ModelInputForGPUWithSamplingMetadata
):
...
...
@@ -101,9 +107,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
models) but these arguments are present here for compatibility with
the base-class constructor.
'''
self
.
_maybe_force_supported_attention_backend
()
self
.
_maybe_force_supported_attention_backend
(
model_config
)
super
().
__init__
(
model_config
,
parallel_config
,
...
...
@@ -119,7 +123,12 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# Crash for unsupported encoder/scenarios
assert_enc_dec_mr_supported_scenario
(
self
)
def
_maybe_force_supported_attention_backend
(
self
):
def
_is_xformers_only_encoder_decoder_model
(
self
,
model
:
ModelConfig
)
->
bool
:
return
get_architecture_class_name
(
model
)
in
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS
def
_maybe_force_supported_attention_backend
(
self
,
model
:
ModelConfig
):
'''
Force vLLM to use the XFormers attention backend,
which is currently the only supported option.
...
...
@@ -135,22 +144,26 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_forced_by_global
=
maybe_global_forced_backend
is
not
None
is_forced_by_env_var
=
maybe_env_var_forced_backend
is
not
None
if
not
(
is_forced_by_global
or
is_forced_by_env_var
):
if
not
(
is_forced_by_global
or
is_forced_by_env_var
)
\
and
self
.
_is_xformers_only_encoder_decoder_model
(
model
):
# The user has not already specified an attention backend
# override
logger
.
info
(
"EncoderDecoderModelRunner requires "
"XFormers backend; overriding backend "
"auto-selection and forcing XFormers."
)
logger
.
info
(
"Encoder-Decoder Model Architecture %s requires XFormers "
"backend; overriding backend auto-selection and "
"forcing XFormers."
,
get_architecture_class_name
(
model
))
global_force_attn_backend
(
_Backend
.
XFORMERS
)
elif
is_forced_by_global
:
# Backend override enforced by global variable takes
# precedence over vLLM backend environment variable.
if
maybe_global_forced_backend
!=
_Backend
.
XFORMERS
:
if
maybe_global_forced_backend
not
in
\
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]:
raise_backend_err
()
elif
is_forced_by_env_var
:
# Backend override enforced by vLLM backend
# environment variable
if
maybe_env_var_forced_backend
!=
_Backend
.
XFORMERS
:
if
maybe_env_var_forced_backend
not
in
\
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]:
raise_backend_err
()
def
_list_to_int32_tensor
(
...
...
@@ -532,6 +545,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
attn_metadata
.
encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens_tensor
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
encoder_seq_start_loc
,
attn_metadata
.
cross_slot_mapping
,
attn_metadata
.
cross_block_tables
,
)
=
(
...
...
@@ -539,6 +553,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_seq_lens
,
encoder_seq_lens_tensor
,
max_encoder_seq_len
,
encoder_seq_start_loc
,
cross_slot_mapping_tensor
,
cross_block_tables
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment