Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e20c92bb
Unverified
Commit
e20c92bb
authored
Jan 07, 2025
by
Chen Zhang
Committed by
GitHub
Jan 07, 2025
Browse files
[Kernel] Move attn_type to Attention.__init__() (#11690)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
32c9eff2
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
159 additions
and
201 deletions
+159
-201
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+49
-51
tests/kernels/utils.py
tests/kernels/utils.py
+7
-5
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-1
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+7
-7
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+7
-8
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+7
-6
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-6
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+7
-6
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+7
-7
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+3
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+4
-2
vllm/attention/layer.py
vllm/attention/layer.py
+10
-27
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+11
-33
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+3
-7
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+4
-7
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+16
-19
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+7
-7
No files found.
tests/kernels/test_encoder_decoder_attn.py
View file @
e20c92bb
...
@@ -13,8 +13,7 @@ import pytest
...
@@ -13,8 +13,7 @@ import pytest
import
torch
import
torch
from
tests.kernels.utils
import
*
from
tests.kernels.utils
import
*
from
vllm.attention
import
(
Attention
,
AttentionBackend
,
AttentionMetadata
,
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
AttentionType
)
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
global_force_attn_backend_context_manager
)
...
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
...
@@ -64,6 +63,7 @@ class TestPoint(NamedTuple):
max_dec_seq_len
:
int
max_dec_seq_len
:
int
max_enc_seq_len
:
int
max_enc_seq_len
:
int
num_blocks
:
int
num_blocks
:
int
attn_type
:
AttentionType
class
TestResources
(
NamedTuple
):
class
TestResources
(
NamedTuple
):
...
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
...
@@ -96,7 +96,6 @@ class TestResources(NamedTuple):
'''
'''
scale
:
float
scale
:
float
attn_backend
:
AttentionBackend
attn
:
Attention
attn
:
Attention
kv_cache
:
torch
.
Tensor
kv_cache
:
torch
.
Tensor
...
@@ -129,16 +128,17 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
...
@@ -129,16 +128,17 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
'''
'''
scale
=
float
(
1.0
/
(
test_pt
.
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
test_pt
.
head_size
**
0.5
))
attn_backend
=
make_backend
(
test_pt
.
backend_name
)
attn
=
Attention
(
attn
=
Attention
(
test_pt
.
num_heads
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
head_size
,
scale
=
scale
,
scale
=
scale
,
prefix
=
f
"
{
test_pt
.
attn_type
}
"
,
attn_type
=
test_pt
.
attn_type
,
)
)
if
test_pt
.
num_blocks
is
None
or
test_pt
.
num_heads
is
None
:
if
test_pt
.
num_blocks
is
None
or
test_pt
.
num_heads
is
None
:
# Caller does not require a KV cache
# Caller does not require a KV cache
return
TestResources
(
return
TestResources
(
scale
,
attn_backend
,
attn
,
scale
,
attn
,
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
# Construct KV cache
# Construct KV cache
...
@@ -148,7 +148,7 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
...
@@ -148,7 +148,7 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
test_pt
.
block_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
backend
=
test_pt
.
backend_name
)
return
TestResources
(
scale
,
attn_backend
,
attn
,
kv_cache
)
return
TestResources
(
scale
,
attn
,
kv_cache
)
def
_encoder_attn_setup
(
def
_encoder_attn_setup
(
...
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
...
@@ -193,6 +193,7 @@ def _encoder_attn_setup(
_
,
_
,
max_q_seq_len
,
max_q_seq_len
,
_
,
_
,
_
,
)
=
test_pt
)
=
test_pt
scale
=
test_rsrcs
.
scale
scale
=
test_rsrcs
.
scale
...
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
...
@@ -301,6 +302,7 @@ def _decoder_attn_setup(
max_q_seq_len
,
max_q_seq_len
,
_
,
_
,
_
,
_
,
_
,
)
=
test_pt
)
=
test_pt
scale
=
test_rsrcs
.
scale
scale
=
test_rsrcs
.
scale
...
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
...
@@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
max_decoder_seq_len
,
max_decoder_seq_len
,
max_encoder_seq_len
,
max_encoder_seq_len
,
_
,
_
,
_
,
)
=
test_pt
)
=
test_pt
scale
=
test_rsrcs
.
scale
scale
=
test_rsrcs
.
scale
...
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
...
@@ -622,7 +625,6 @@ def _run_encoder_attention_test(
& attn_metadata
& attn_metadata
'''
'''
assert
attn_metadata
.
num_decode_tokens
==
0
assert
attn_metadata
.
num_decode_tokens
==
0
attn_type
=
AttentionType
.
ENCODER
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
encoder_test_params
.
packed_qkvo
.
packed_qkv
assert
packed_qkv
is
not
None
assert
packed_qkv
is
not
None
with
set_forward_context
(
attn_metadata
,
vllm_config
):
with
set_forward_context
(
attn_metadata
,
vllm_config
):
...
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
...
@@ -635,14 +637,11 @@ def _run_encoder_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
return
attn
.
forward
(
packed_qkv
.
key
,
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
packed_qkv
.
value
,
torch
.
tensor
([],
torch
.
tensor
([],
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
packed_qkv
.
query
.
device
),
attn_metadata
)
device
=
packed_qkv
.
query
.
device
),
attn_metadata
,
attn_type
=
attn_type
)
def
_run_decoder_self_attention_test
(
def
_run_decoder_self_attention_test
(
...
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
...
@@ -675,7 +674,6 @@ def _run_decoder_self_attention_test(
* Attention.forward() applied to packed_{query,key,value}, kv_cache
* Attention.forward() applied to packed_{query,key,value}, kv_cache
& attn_metadata
& attn_metadata
'''
'''
attn_type
=
AttentionType
.
DECODER
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
kv_cache
=
test_rsrcs
.
kv_cache
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
packed_qkv
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
...
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
...
@@ -690,12 +688,8 @@ def _run_decoder_self_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
packed_qkv
.
query
.
view
(
reshaped_query
=
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
return
attn
.
forward
(
reshaped_query
,
packed_qkv
.
key
,
packed_qkv
.
value
,
packed_qkv
.
key
,
kv_cache
,
attn_metadata
)
packed_qkv
.
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
def
_run_encoder_decoder_cross_attention_test
(
def
_run_encoder_decoder_cross_attention_test
(
...
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test(
'''
'''
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
assert
decoder_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
attn_type
=
AttentionType
.
ENCODER_DECODER
attn
=
test_rsrcs
.
attn
attn
=
test_rsrcs
.
attn
kv_cache
=
test_rsrcs
.
kv_cache
kv_cache
=
test_rsrcs
.
kv_cache
if
cross_test_params
is
None
:
if
cross_test_params
is
None
:
...
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
...
@@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test(
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
# is shaped as [num_tokens, hidden_size] and we can skip the reshape.
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
reshaped_query
=
decoder_test_params
.
packed_qkvo
.
packed_qkv
.
query
.
view
(
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
-
1
,
test_pt
.
num_heads
*
test_pt
.
head_size
)
return
attn
.
forward
(
reshaped_query
,
return
attn
.
forward
(
reshaped_query
,
key
,
value
,
kv_cache
,
key
,
attn_metadata
)
value
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
@@ -839,7 +828,7 @@ def test_encoder_only(
...
@@ -839,7 +828,7 @@ def test_encoder_only(
# is not part of this test
# is not part of this test
test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
)
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER
)
# Attention scale factor, attention backend instance, attention wrapper
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
# instance, KV cache init
...
@@ -855,7 +844,7 @@ def test_encoder_only(
...
@@ -855,7 +844,7 @@ def test_encoder_only(
# Shared prefill metadata structure
# Shared prefill metadata structure
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
True
,
True
,
None
,
None
,
decoder_test_params
=
None
,
decoder_test_params
=
None
,
...
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
...
@@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn(
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# Note: KV cache size of 4096 is arbitrary & chosen intentionally
# to be more than necessary, since exceeding the kv cache size
# to be more than necessary, since exceeding the kv cache size
# is not part of this test
# is not part of this test
test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
enc_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
)
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER
)
enc_dec_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
,
AttentionType
.
ENCODER_DECODER
)
dec_test_pt
=
TestPoint
(
num_heads
,
head_size
,
attn_backend
.
name
,
batch_size
,
block_size
,
max_dec_seq_len
,
max_enc_seq_len
,
4096
,
AttentionType
.
DECODER
)
# Attention scale factor, attention backend instance, attention wrapper
# Attention scale factor, attention backend instance, attention wrapper
# instance, KV cache init
# instance, KV cache init
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
test_rsrcs
=
_make_test_resources
(
test_pt
)
enc_test_rsrcs
=
_make_test_resources
(
enc_test_pt
)
enc_dec_test_rsrcs
=
_make_test_resources
(
enc_dec_test_pt
)
dec_test_rsrcs
=
_make_test_resources
(
dec_test_pt
)
# Construct encoder attention test params (only used
# Construct encoder attention test params (only used
# during prefill)
# during prefill)
enc_test_params
=
_encoder_attn_setup
(
test_pt
,
test_rsrcs
)
enc_test_params
=
_encoder_attn_setup
(
enc_
test_pt
,
enc_
test_rsrcs
)
# Construct Decoder self-attention prefill-phase & decode-phase
# Construct Decoder self-attention prefill-phase & decode-phase
# test params, including query/key/value tensors, decoder self-attention
# test params, including query/key/value tensors, decoder self-attention
...
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
...
@@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn(
prephase_dec_test_params
,
prephase_dec_test_params
,
decphase_dec_test_params
,
decphase_dec_test_params
,
cross_block_base_addr
,
cross_block_base_addr
,
)
=
_decoder_attn_setup
(
test_pt
,
test_rsrcs
)
)
=
_decoder_attn_setup
(
dec_
test_pt
,
dec_
test_rsrcs
)
# Construct encoder/decoder cross-attention prefill-phase
# Construct encoder/decoder cross-attention prefill-phase
# & decode-phase test params, including key/value tensors,
# & decode-phase test params, including key/value tensors,
...
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
...
@@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn(
dec_qkv
,
dec_qkv
,
enc_test_params
,
enc_test_params
,
prephase_dec_test_params
,
prephase_dec_test_params
,
test_pt
,
enc_dec_
test_pt
,
test_rsrcs
,
enc_dec_
test_rsrcs
,
block_base_addr
=
cross_block_base_addr
)
block_base_addr
=
cross_block_base_addr
)
# Shared prefill metadata structure
# Shared prefill metadata structure
assert
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
assert
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
is
not
None
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
prephase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
True
,
True
,
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
.
q_seq_lens
,
prephase_dec_test_params
.
packed_qkvo
.
packed_qkv
.
q_seq_lens
,
decoder_test_params
=
prephase_dec_test_params
,
decoder_test_params
=
prephase_dec_test_params
,
...
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(
...
@@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder attention
# PREFILL: encoder attention
enc_pckd_act_out
=
_run_encoder_attention_test
(
test_rsrcs
.
attn
,
enc_pckd_act_out
=
_run_encoder_attention_test
(
enc_
test_rsrcs
.
attn
,
enc_test_params
,
enc_test_params
,
prephase_attn_metadata
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_
test_pt
,
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
# - Is encoder attention result correct?
# - Is encoder attention result correct?
...
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
...
@@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn(
# PREFILL: decoder self-attention test
# PREFILL: decoder self-attention test
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
prephase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
dec_
test_rsrcs
,
prephase_dec_test_params
,
prephase_dec_test_params
,
prephase_attn_metadata
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
dec_
test_pt
,
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
# - Is prefill decoder self-attention correct?
# - Is prefill decoder self-attention correct?
...
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
...
@@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn(
# PREFILL: encoder/decoder cross-attention test
# PREFILL: encoder/decoder cross-attention test
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
prephase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
enc_dec_
test_rsrcs
,
prephase_dec_test_params
,
prephase_dec_test_params
,
prephase_cross_test_params
,
prephase_cross_test_params
,
prephase_attn_metadata
,
prephase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_dec_
test_pt
,
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
# - Is prefill encoder/decoder cross-attention correct?
# - Is prefill encoder/decoder cross-attention correct?
...
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
...
@@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn(
# DECODE: build decode-phase attention metadata
# DECODE: build decode-phase attention metadata
decphase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
decphase_attn_metadata
:
AttentionMetadata
=
make_test_metadata
(
test_rsrcs
.
attn_backend
,
attn_backend
,
False
,
False
,
dec_qkv
.
q_seq_lens
,
dec_qkv
.
q_seq_lens
,
decoder_test_params
=
decphase_dec_test_params
,
decoder_test_params
=
decphase_dec_test_params
,
...
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
...
@@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn(
# DECODE: decoder self-attention test
# DECODE: decoder self-attention test
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
decphase_dec_pckd_act_out
=
_run_decoder_self_attention_test
(
test_rsrcs
,
dec_
test_rsrcs
,
decphase_dec_test_params
,
decphase_dec_test_params
,
decphase_attn_metadata
,
decphase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
dec_
test_pt
,
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
# - Is decode-phase decoder self-attention correct?
# - Is decode-phase decoder self-attention correct?
...
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
...
@@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn(
# DECODE: encoder/decoder cross-attention test
# DECODE: encoder/decoder cross-attention test
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
decphase_cross_pckd_act_out
=
_run_encoder_decoder_cross_attention_test
(
test_rsrcs
,
enc_dec_
test_rsrcs
,
decphase_dec_test_params
,
decphase_dec_test_params
,
None
,
None
,
decphase_attn_metadata
,
decphase_attn_metadata
,
test_pt
=
test_pt
,
test_pt
=
enc_dec_
test_pt
,
vllm_config
=
vllm_config
)
vllm_config
=
vllm_config
)
# - Is decode-phase encoder/decoder cross-attention correct?
# - Is decode-phase encoder/decoder cross-attention correct?
...
...
tests/kernels/utils.py
View file @
e20c92bb
...
@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
...
@@ -13,6 +13,7 @@ from torch._prims_common import TensorLikeType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.platforms.interface
import
_Backend
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_FLASH_ATTN_VAL
,
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
STR_XFORMERS_ATTN_VAL
,
make_tensor_with_pad
)
...
@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
...
@@ -790,7 +791,7 @@ def make_block_tables_slot_mapping(
def
make_test_metadata
(
def
make_test_metadata
(
attn_backend
:
Attention
Backend
,
attn_backend
:
_
Backend
,
is_prompt
:
bool
,
is_prompt
:
bool
,
seq_lens
:
Optional
[
List
[
int
]],
seq_lens
:
Optional
[
List
[
int
]],
decoder_test_params
:
Optional
[
PhaseTestParameters
],
decoder_test_params
:
Optional
[
PhaseTestParameters
],
...
@@ -815,7 +816,7 @@ def make_test_metadata(
...
@@ -815,7 +816,7 @@ def make_test_metadata(
Arguments:
Arguments:
* attn_backend: Backend for sourcing attention kernels
* attn_backend
_name
: Backend for sourcing attention kernels
* is_prompt: prefill if True, o/w decode
* is_prompt: prefill if True, o/w decode
* seq_lens: list of token counts for each sequence
* seq_lens: list of token counts for each sequence
* decoder_test_params: decoder self-attention test params;
* decoder_test_params: decoder self-attention test params;
...
@@ -882,6 +883,8 @@ def make_test_metadata(
...
@@ -882,6 +883,8 @@ def make_test_metadata(
# (kv_mmap)
# (kv_mmap)
cross_kv_mmap
=
cross_test_params
.
kv_mmap
cross_kv_mmap
=
cross_test_params
.
kv_mmap
attn_backend_obj
=
make_backend
(
attn_backend
.
name
)
if
is_prompt
:
if
is_prompt
:
# Prefill-phase scenario
# Prefill-phase scenario
...
@@ -902,8 +905,7 @@ def make_test_metadata(
...
@@ -902,8 +905,7 @@ def make_test_metadata(
context_lens
,
context_lens
,
encoder_seq_lens
,
encoder_seq_lens
,
device
=
device
)
device
=
device
)
return
attn_backend_obj
.
make_metadata
(
return
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
multi_modal_placeholder_index_maps
=
None
,
multi_modal_placeholder_index_maps
=
None
,
...
@@ -952,7 +954,7 @@ def make_test_metadata(
...
@@ -952,7 +954,7 @@ def make_test_metadata(
encoder_seq_lens
,
encoder_seq_lens
,
device
=
device
)
device
=
device
)
return
attn_backend
.
make_metadata
(
return
attn_backend
_obj
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
multi_modal_placeholder_index_maps
=
None
,
...
...
vllm/attention/backends/abstract.py
View file @
e20c92bb
...
@@ -233,6 +233,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -233,6 +233,7 @@ class AttentionImpl(ABC, Generic[T]):
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -246,7 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -246,7 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata
:
T
,
attn_metadata
:
T
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
e20c92bb
...
@@ -300,6 +300,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -300,6 +300,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
not
None
assert
blocksparse_params
is
not
None
assert
alibi_slopes
is
None
,
ValueError
(
assert
alibi_slopes
is
None
,
ValueError
(
...
@@ -350,6 +351,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -350,6 +351,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
active_head_range
=
self
.
blocksparse_params
.
active_head_range
,
active_head_range
=
self
.
blocksparse_params
.
active_head_range
,
)
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -359,7 +366,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -359,7 +366,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -375,12 +381,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -375,12 +381,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"BlocksparseFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/flash_attn.py
View file @
e20c92bb
...
@@ -600,6 +600,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -600,6 +600,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -627,6 +628,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -627,6 +628,7 @@ class FlashAttentionImpl(AttentionImpl):
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
def
forward
(
def
forward
(
self
,
self
,
...
@@ -637,7 +639,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -637,7 +639,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
"""Forward pass with FlashAttention.
...
@@ -659,6 +660,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -659,6 +660,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
raise
AttributeError
(
"Encoder attention requires setting "
...
...
vllm/attention/backends/flashinfer.py
View file @
e20c92bb
...
@@ -748,6 +748,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -748,6 +748,7 @@ class FlashInferImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -764,6 +765,12 @@ class FlashInferImpl(AttentionImpl):
...
@@ -764,6 +765,12 @@ class FlashInferImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -773,18 +780,10 @@ class FlashInferImpl(AttentionImpl):
...
@@ -773,18 +780,10 @@ class FlashInferImpl(AttentionImpl):
attn_metadata
:
FlashInferMetadata
,
attn_metadata
:
FlashInferMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# TODO: directly write to output tensor
# TODO: directly write to output tensor
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl"
)
num_heads
:
int
=
self
.
num_heads
num_heads
:
int
=
self
.
num_heads
head_size
:
int
=
self
.
head_size
head_size
:
int
=
self
.
head_size
num_kv_heads
:
int
=
self
.
num_kv_heads
num_kv_heads
:
int
=
self
.
num_kv_heads
...
...
vllm/attention/backends/hpu_attn.py
View file @
e20c92bb
...
@@ -102,6 +102,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -102,6 +102,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_seq_len
:
int
=
4096
,
max_seq_len
:
int
=
4096
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
super
(
AttentionImpl
,
self
).
__init__
()
super
(
AttentionImpl
,
self
).
__init__
()
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
...
@@ -143,6 +144,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -143,6 +144,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -152,7 +159,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -152,7 +159,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
attn_metadata
:
HPUAttentionMetadata
,
attn_metadata
:
HPUAttentionMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -166,11 +172,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -166,11 +172,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
...
...
vllm/attention/backends/ipex_attn.py
View file @
e20c92bb
...
@@ -115,6 +115,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -115,6 +115,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -146,6 +147,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -146,6 +147,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
raise
NotImplementedError
(
raise
NotImplementedError
(
"IPEX backend does not support FP8 KV cache. "
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
"Please use xFormers backend instead."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl"
)
def
split_kv_cache
(
def
split_kv_cache
(
self
,
self
,
...
@@ -172,7 +178,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -172,7 +178,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
@@ -189,11 +194,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -189,11 +194,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"IpexAttnBackendImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/pallas.py
View file @
e20c92bb
...
@@ -100,6 +100,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -100,6 +100,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -141,6 +142,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -141,6 +142,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
# megacore mode will be None.
# megacore mode will be None.
self
.
megacore_mode
=
"batch"
self
.
megacore_mode
=
"batch"
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -150,7 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -150,7 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata
:
PallasMetadata
,
attn_metadata
:
PallasMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
"""Forward pass with Pallas attention.
...
@@ -168,11 +174,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -168,11 +174,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
shape = [batch_size, seq_len, num_heads * head_size]
shape = [batch_size, seq_len, num_heads * head_size]
"""
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
e20c92bb
...
@@ -338,6 +338,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -338,6 +338,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -397,6 +398,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -397,6 +398,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
attn_func
=
_sdpa_attention
self
.
attn_func
=
_sdpa_attention
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
...
@@ -414,7 +421,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -414,7 +421,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -432,12 +438,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -432,12 +438,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
"""
# Reminder: Please update docs/source/features/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
# If the feature combo become valid
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"ROCmFlashAttentionImpl"
)
num_tokens
,
hidden_size
=
query
.
shape
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
e20c92bb
...
@@ -390,6 +390,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -390,6 +390,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -421,6 +422,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -421,6 +422,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Torch SDPA backend does not support FP8 KV cache. "
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
"Please use xFormers backend instead."
)
self
.
attn_type
=
attn_type
def
forward
(
def
forward
(
self
,
self
,
...
@@ -431,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -431,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -448,6 +449,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -448,6 +449,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
k_scale
==
1.0
and
v_scale
==
1.0
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
raise
AttributeError
(
"Encoder attention requires setting "
...
...
vllm/attention/backends/xformers.py
View file @
e20c92bb
...
@@ -379,6 +379,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -379,6 +379,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -405,6 +406,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -405,6 +406,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
f
"Supported head sizes are:
{
suppored_head_sizes
}
."
)
self
.
attn_type
=
attn_type
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -414,7 +417,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -414,7 +417,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata
:
"XFormersMetadata"
,
attn_metadata
:
"XFormersMetadata"
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
str
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -468,7 +470,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -468,7 +470,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
attn_type
=
self
.
attn_type
# Check that appropriate attention metadata attributes are
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
# selected for the desired attention type
if
(
attn_type
==
AttentionType
.
ENCODER
if
(
attn_type
==
AttentionType
.
ENCODER
...
...
vllm/attention/layer.py
View file @
e20c92bb
...
@@ -41,6 +41,7 @@ class Attention(nn.Module):
...
@@ -41,6 +41,7 @@ class Attention(nn.Module):
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
per_layer_sliding_window
:
Optional
[
int
]
=
None
,
per_layer_sliding_window
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
per_layer_sliding_window
is
not
None
:
if
per_layer_sliding_window
is
not
None
:
...
@@ -96,7 +97,7 @@ class Attention(nn.Module):
...
@@ -96,7 +97,7 @@ class Attention(nn.Module):
impl_cls
=
attn_backend
.
get_impl_cls
()
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
)
blocksparse_params
,
logits_soft_cap
,
attn_type
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
num_kv_heads
=
num_kv_heads
self
.
num_kv_heads
=
num_kv_heads
...
@@ -119,6 +120,7 @@ class Attention(nn.Module):
...
@@ -119,6 +120,7 @@ class Attention(nn.Module):
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
raise
ValueError
(
f
"Duplicate layer name:
{
prefix
}
"
)
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
attn_type
=
attn_type
def
forward
(
def
forward
(
self
,
self
,
...
@@ -127,18 +129,12 @@ class Attention(nn.Module):
...
@@ -127,18 +129,12 @@ class Attention(nn.Module):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
return
self
.
impl
.
forward
(
query
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
key
,
attn_metadata
,
self
.
_k_scale
,
value
,
self
.
_v_scale
)
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
elif
self
.
use_output
:
elif
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
hidden_size
=
query
.
size
(
-
1
)
hidden_size
=
query
.
size
(
-
1
)
...
@@ -152,13 +148,11 @@ class Attention(nn.Module):
...
@@ -152,13 +148,11 @@ class Attention(nn.Module):
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
kv_cache
,
attn_type
,
query
,
key
,
value
,
output
,
kv_cache
,
self
.
layer_name
)
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
kv_cache
,
attn_type
,
kv_cache
,
self
.
layer_name
)
self
.
layer_name
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
@@ -237,20 +231,13 @@ def unified_attention(
...
@@ -237,20 +231,13 @@ def unified_attention(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
attn_metadata
=
forward_context
.
dynamic_forward_context
self
=
forward_context
.
static_forward_context
[
layer_name
]
self
=
forward_context
.
static_forward_context
[
layer_name
]
return
self
.
impl
.
forward
(
query
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
key
,
self
.
_k_scale
,
self
.
_v_scale
)
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
def
unified_attention_fake
(
def
unified_attention_fake
(
...
@@ -258,7 +245,6 @@ def unified_attention_fake(
...
@@ -258,7 +245,6 @@ def unified_attention_fake(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
return
torch
.
empty_like
(
query
).
contiguous
()
...
@@ -279,7 +265,6 @@ def unified_attention_with_output(
...
@@ -279,7 +265,6 @@ def unified_attention_with_output(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
layer_name
:
str
,
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
...
@@ -292,7 +277,6 @@ def unified_attention_with_output(
...
@@ -292,7 +277,6 @@ def unified_attention_with_output(
attn_metadata
,
attn_metadata
,
self
.
_k_scale
,
self
.
_k_scale
,
self
.
_v_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
,
output
=
output
)
output
=
output
)
...
@@ -302,7 +286,6 @@ def unified_attention_with_output_fake(
...
@@ -302,7 +286,6 @@ def unified_attention_with_output_fake(
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_type
:
str
,
layer_name
:
str
,
layer_name
:
str
,
)
->
None
:
)
->
None
:
return
return
...
...
vllm/model_executor/models/bart.py
View file @
e20c92bb
...
@@ -71,12 +71,8 @@ class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
...
@@ -71,12 +71,8 @@ class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
attn_type
:
AttentionType
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""`input_ids' shape is expected to be [bsz x seqlen]."""
"""`input_ids' shape is expected to be [bsz x seqlen]."""
assert
attn_type
!=
AttentionType
.
ENCODER_DECODER
return
super
().
forward
(
positions
+
self
.
offset
)
return
super
().
forward
(
positions
+
self
.
offset
)
...
@@ -180,7 +176,8 @@ class BartEncoderAttention(nn.Module):
...
@@ -180,7 +176,8 @@ class BartEncoderAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
...
@@ -189,12 +186,7 @@ class BartEncoderAttention(nn.Module):
...
@@ -189,12 +186,7 @@ class BartEncoderAttention(nn.Module):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -264,7 +256,8 @@ class BartDecoderSelfAttention(nn.Module):
...
@@ -264,7 +256,8 @@ class BartDecoderSelfAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
DECODER
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
attn_metadata
:
AttentionMetadata
)
->
torch
.
Tensor
:
...
@@ -273,12 +266,7 @@ class BartDecoderSelfAttention(nn.Module):
...
@@ -273,12 +266,7 @@ class BartDecoderSelfAttention(nn.Module):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
DECODER
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -348,7 +336,8 @@ class BartCrossAttention(nn.Module):
...
@@ -348,7 +336,8 @@ class BartCrossAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -372,12 +361,7 @@ class BartCrossAttention(nn.Module):
...
@@ -372,12 +361,7 @@ class BartCrossAttention(nn.Module):
_
,
k
,
v
=
qkv_enc
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
_
,
k
,
v
=
qkv_enc
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -644,10 +628,7 @@ class BartEncoder(nn.Module):
...
@@ -644,10 +628,7 @@ class BartEncoder(nn.Module):
# retrieve input_ids and inputs_embeds
# retrieve input_ids and inputs_embeds
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
embed_pos
=
self
.
embed_positions
(
embed_pos
=
self
.
embed_positions
(
positions
)
positions
,
AttentionType
.
ENCODER
,
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
hidden_states
=
inputs_embeds
+
embed_pos
hidden_states
=
inputs_embeds
+
embed_pos
...
@@ -734,10 +715,7 @@ class BartDecoder(nn.Module):
...
@@ -734,10 +715,7 @@ class BartDecoder(nn.Module):
inputs_embeds
=
self
.
embed_tokens
(
decoder_input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
decoder_input_ids
)
# embed positions
# embed positions
embed_pos
=
self
.
embed_positions
(
embed_pos
=
self
.
embed_positions
(
decoder_positions
)
decoder_positions
,
AttentionType
.
DECODER
,
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
hidden_states
=
inputs_embeds
+
embed_pos
hidden_states
=
inputs_embeds
+
embed_pos
...
...
vllm/model_executor/models/bert.py
View file @
e20c92bb
...
@@ -238,7 +238,8 @@ class BertSelfAttention(nn.Module):
...
@@ -238,7 +238,8 @@ class BertSelfAttention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER_ONLY
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -248,12 +249,7 @@ class BertSelfAttention(nn.Module):
...
@@ -248,12 +249,7 @@ class BertSelfAttention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
output
=
self
.
attn
(
q
,
output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_ONLY
)
return
output
return
output
...
...
vllm/model_executor/models/mllama.py
View file @
e20c92bb
...
@@ -770,6 +770,7 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -770,6 +770,7 @@ class MllamaTextCrossAttention(nn.Module):
self
.
scaling
,
self
.
scaling
,
self
.
num_local_key_value_heads
,
self
.
num_local_key_value_heads
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
AttentionType
.
ENCODER_DECODER
,
)
)
def
forward
(
def
forward
(
...
@@ -805,13 +806,9 @@ class MllamaTextCrossAttention(nn.Module):
...
@@ -805,13 +806,9 @@ class MllamaTextCrossAttention(nn.Module):
kv_range_for_decode
,
kv_range_for_decode
,
attn_metadata
)
attn_metadata
)
else
:
else
:
output
=
self
.
attn
(
q
.
view
(
-
1
,
output
=
self
.
attn
(
self
.
num_local_heads
*
self
.
head_dim
),
q
.
view
(
-
1
,
self
.
num_local_heads
*
self
.
head_dim
),
k
,
v
,
k
,
kv_cache
,
attn_metadata
)
v
,
kv_cache
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
out
,
_
=
self
.
o_proj
(
output
)
out
,
_
=
self
.
o_proj
(
output
)
return
out
return
out
...
...
vllm/model_executor/models/qwen2.py
View file @
e20c92bb
...
@@ -107,7 +107,8 @@ class Qwen2Attention(nn.Module):
...
@@ -107,7 +107,8 @@ class Qwen2Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
,
rope_scaling
:
Optional
[
Tuple
]
=
None
,
prefix
:
str
=
""
)
->
None
:
prefix
:
str
=
""
,
attn_type
:
str
=
AttentionType
.
DECODER
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
...
@@ -160,7 +161,8 @@ class Qwen2Attention(nn.Module):
...
@@ -160,7 +161,8 @@ class Qwen2Attention(nn.Module):
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
attn_type
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -168,17 +170,11 @@ class Qwen2Attention(nn.Module):
...
@@ -168,17 +170,11 @@ class Qwen2Attention(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
attn_type
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -197,6 +193,16 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -197,6 +193,16 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
# Requires transformers > 4.32.0
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if
getattr
(
config
,
"is_causal"
,
True
):
attn_type
=
AttentionType
.
DECODER
else
:
attn_type
=
AttentionType
.
ENCODER_ONLY
self
.
self_attn
=
Qwen2Attention
(
self
.
self_attn
=
Qwen2Attention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -207,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -207,6 +213,7 @@ class Qwen2DecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attn_type
=
attn_type
,
)
)
self
.
mlp
=
Qwen2MLP
(
self
.
mlp
=
Qwen2MLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -220,15 +227,6 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -220,15 +227,6 @@ class Qwen2DecoderLayer(nn.Module):
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if
getattr
(
config
,
"is_causal"
,
True
):
self
.
_attn_type
=
AttentionType
.
DECODER
else
:
self
.
_attn_type
=
AttentionType
.
ENCODER_ONLY
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -249,7 +247,6 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -249,7 +247,6 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
attn_type
=
self
.
_attn_type
,
)
)
# Fully Connected
# Fully Connected
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
e20c92bb
...
@@ -89,6 +89,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -89,6 +89,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
None
:
)
->
None
:
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -119,6 +120,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -119,6 +120,12 @@ class FlashAttentionImpl(AttentionImpl):
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -128,7 +135,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -128,7 +135,6 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata
:
FlashAttentionMetadata
,
attn_metadata
:
FlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
"""Forward pass with FlashAttention.
...
@@ -142,12 +148,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -142,12 +148,6 @@ class FlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
"key/v_scale is not supported in FlashAttention."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment