Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
805a8a75
Unverified
Commit
805a8a75
authored
Aug 01, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 01, 2024
Browse files
[Misc] Support attention logits soft-capping with flash-attn (#7022)
parent
562e580a
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
71 additions
and
47 deletions
+71
-47
requirements-cuda.txt
requirements-cuda.txt
+1
-1
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+13
-6
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+10
-11
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+5
-9
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-2
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+4
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+0
-9
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+7
-2
vllm/attention/layer.py
vllm/attention/layer.py
+2
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+5
-2
No files found.
requirements-cuda.txt
View file @
805a8a75
...
@@ -8,4 +8,4 @@ torch == 2.4.0
...
@@ -8,4 +8,4 @@ torch == 2.4.0
# These must be updated alongside torch
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0
vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0
tests/kernels/test_flash_attn.py
View file @
805a8a75
...
@@ -20,6 +20,7 @@ def ref_paged_attn(
...
@@ -20,6 +20,7 @@ def ref_paged_attn(
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
scale
:
float
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
block_tables
=
block_tables
.
cpu
().
numpy
()
...
@@ -53,6 +54,8 @@ def ref_paged_attn(
...
@@ -53,6 +54,8 @@ def ref_paged_attn(
(
query_len
+
sliding_window
)
+
(
query_len
+
sliding_window
)
+
1
).
bool
().
logical_not
()
1
).
bool
().
logical_not
()
mask
|=
sliding_window_mask
mask
|=
sliding_window_mask
if
soft_cap
is
not
None
:
attn
=
soft_cap
*
torch
.
tanh
(
attn
/
soft_cap
)
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
...
@@ -68,13 +71,15 @@ def ref_paged_attn(
...
@@ -68,13 +71,15 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal
=
True
,
causal
=
True
,
block_table
=
block_tables
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
...
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens
=
kv_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f"{torch.max(torch.abs(output - ref_output))}"
f"{torch.max(torch.abs(output - ref_output))}"
...
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
...
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
torch
.
inference_mode
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
num_heads
:
Tuple
[
int
,
int
],
...
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
...
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
...
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size
,
head_size
,
dtype
=
dtype
)
dtype
=
dtype
)
value_cache
=
torch
.
randn_like
(
key_cache
)
value_cache
=
torch
.
randn_like
(
key_cache
)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
...
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
...
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
window_size
=
window_size
,
block_table
=
block_tables
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
...
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
...
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f"{torch.max(torch.abs(output - ref_output))}"
f"{torch.max(torch.abs(output - ref_output))}"
vllm/attention/backends/abstract.py
View file @
805a8a75
...
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
805a8a75
...
@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
...
@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
not
None
assert
blocksparse_params
is
not
None
assert
alibi_slopes
is
None
,
ValueError
(
assert
alibi_slopes
is
None
,
ValueError
(
"Alibi not support for blocksparse flash attention."
)
"Alibi not support for blocksparse flash attention."
)
assert
sliding_window
is
None
,
ValueError
(
assert
sliding_window
is
None
,
ValueError
(
"sliding_window is invalid for blocksparse attention."
)
"sliding_window is invalid for blocksparse attention."
)
assert
logits_soft_cap
is
None
,
ValueError
(
"logits_soft_cap is invalid for blocksparse attention."
)
if
"num_heads"
not
in
blocksparse_params
:
if
"num_heads"
not
in
blocksparse_params
:
blocksparse_params
[
"num_heads"
]
=
num_heads
blocksparse_params
[
"num_heads"
]
=
num_heads
...
...
vllm/attention/backends/flash_attn.py
View file @
805a8a75
...
@@ -288,15 +288,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -288,15 +288,6 @@ class FlashAttentionMetadataBuilder(
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
...
@@ -405,9 +396,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -405,9 +396,11 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"FlashAttention does not support block-sparse attention."
)
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
@@ -418,6 +411,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -418,6 +411,10 @@ class FlashAttentionImpl(AttentionImpl):
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -525,6 +522,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -525,6 +522,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
...
@@ -544,6 +542,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -544,6 +542,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
)
)
if decode_meta := attn_metadata.decode_metadata:
if decode_meta := attn_metadata.decode_metadata:
...
...
vllm/attention/backends/flashinfer.py
View file @
805a8a75
...
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
# Only used by gemma2 model
logits_soft_cap
:
Optional
[
float
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Refer to
# Refer to
...
@@ -391,9 +389,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -391,9 +389,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
len
(
self
.
paged_kv_indptr
)
>
0
:
if
len
(
self
.
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -430,8 +425,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -430,8 +425,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
device
,
device
=
device
,
data_type
=
kv_cache_dtype
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
,
use_cuda_graph
=
use_captured_graph
)
logits_soft_cap
=
logits_soft_cap
)
class
FlashInferImpl
(
AttentionImpl
):
class
FlashInferImpl
(
AttentionImpl
):
...
@@ -446,6 +440,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -446,6 +440,7 @@ class FlashInferImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -458,6 +453,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -458,6 +453,7 @@ class FlashInferImpl(AttentionImpl):
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
@@ -532,7 +528,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -532,7 +528,7 @@ class FlashInferImpl(AttentionImpl):
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
query
,
kv_cache
,
kv_cache
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
,
logits_soft_cap
=
self
.
logits_soft_cap
,
causal
=
True
)
causal
=
True
)
else
:
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
...
@@ -541,5 +537,5 @@ class FlashInferImpl(AttentionImpl):
...
@@ -541,5 +537,5 @@ class FlashInferImpl(AttentionImpl):
query
,
query
,
kv_cache
,
kv_cache
,
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
)
logits_soft_cap
=
self
.
logits_soft_cap
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/ipex_attn.py
View file @
805a8a75
...
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"Torch SPDA does not support block-sparse attention."
)
raise
ValueError
(
"IPEX backend does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"IPEX backend does not support logits_soft_cap."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/pallas.py
View file @
805a8a75
...
@@ -91,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -91,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -109,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -109,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
logits_soft_cap
is
not
None
:
raise
NotImplementedError
(
"Attention logits soft-capping is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
805a8a75
...
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"ROCFlashAttention does not support blocksparse attention."
)
raise
ValueError
(
"ROCmFlashAttention does not support blocksparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCmFlashAttention does not support attention logits soft "
"capping."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
805a8a75
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"Torch SPDA does not support block-sparse attention."
)
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support logits soft cap."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/backends/utils.py
View file @
805a8a75
...
@@ -165,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -165,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
...
...
vllm/attention/backends/xformers.py
View file @
805a8a75
...
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
assert
blocksparse_params
is
None
,
ValueError
(
if
blocksparse_params
is
not
None
:
"XFormer does not support block-sparse attention."
)
raise
ValueError
(
"XFormers does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"XFormers does not support attention logits soft capping."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
...
vllm/attention/layer.py
View file @
805a8a75
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
...
@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -82,7 +83,7 @@ class Attention(nn.Module):
...
@@ -82,7 +83,7 @@ class Attention(nn.Module):
impl_cls
=
attn_backend
.
get_impl_cls
()
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
)
blocksparse_params
,
logits_soft_cap
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/gemma2.py
View file @
805a8a75
...
@@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module):
...
@@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module):
max_position_embeddings
:
int
,
max_position_embeddings
:
int
,
rope_theta
:
float
,
rope_theta
:
float
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
attn_logits_soft_cap
:
Optional
[
float
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
self
.
config
=
config
self
.
config
=
config
...
@@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module):
...
@@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module):
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
logits_soft_cap
=
attn_logits_soft_cap
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module):
rope_theta
=
config
.
rope_theta
,
rope_theta
=
config
.
rope_theta
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
attn_logits_soft_cap
=
config
.
attn_logit_softcapping
,
)
)
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
mlp
=
Gemma2MLP
(
self
.
mlp
=
Gemma2MLP
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment