Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ee3eea0a
Unverified
Commit
ee3eea0a
authored
May 22, 2024
by
Cody Yu
Committed by
GitHub
May 23, 2024
Browse files
[Misc] Take user preference in attention selector (#4960)
parent
a36de682
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
169 additions
and
61 deletions
+169
-61
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+84
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-0
vllm/attention/selector.py
vllm/attention/selector.py
+84
-61
No files found.
tests/kernels/test_attention_selector.py
0 → 100644
View file @
ee3eea0a
import
os
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.attention.selector
import
which_attn_to_use
@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str):
    """Test that the attention selector can be set via environment variable.

    Note that we do not test FlashAttn because it is the default backend.

    Args:
        name: Backend name written to ``VLLM_ATTENTION_BACKEND``.
        device: Which platform to simulate: ``"cpu"`` and ``"hip"`` patch the
            selector's platform probes; anything else uses the real platform.
    """
    # patch.dict restores os.environ on exit -- removing the key again if it
    # was not set beforehand -- even when one of the assertions below fails,
    # so the forced backend never leaks into other tests.
    with patch.dict(os.environ, {"VLLM_ATTENTION_BACKEND": name}):
        if device == "cpu":
            # With is_cpu() reporting True the test expects TORCH_SDPA,
            # whatever backend the environment variable asked for.
            with patch("vllm.attention.selector.is_cpu", return_value=True):
                backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                            torch.float16, 16)
            assert backend.name == "TORCH_SDPA"
        elif device == "hip":
            # With is_hip() reporting True the test expects ROCM_FLASH,
            # whatever backend the environment variable asked for.
            with patch("vllm.attention.selector.is_hip", return_value=True):
                backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                            torch.float16, 16)
            assert backend.name == "ROCM_FLASH"
        else:
            # No platform override: the user's choice must be honoured.
            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                        torch.float16, 16)
            assert backend.name == name
def test_flash_attn():
    """Test FlashAttn validation.

    With ``VLLM_ATTENTION_BACKEND`` forced to ``"FLASH_ATTN"``, every call
    below passes one unsupported setting and checks that the selector does
    not return FLASH_ATTN.
    """
    # patch.dict restores os.environ on exit -- removing the key again if it
    # was not set beforehand -- even when one of the assertions below fails,
    # so the forced backend never leaks into other tests.
    with patch.dict(os.environ, {"VLLM_ATTENTION_BACKEND": "FLASH_ATTN"}):
        # Unsupported CUDA arch
        with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16, None,
                                        16)
            assert backend.name != "FLASH_ATTN"

        # Unsupported data type
        backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None,
                                    16)
        assert backend.name != "FLASH_ATTN"

        # Unsupported kv cache data type
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
        assert backend.name != "FLASH_ATTN"

        # Unsupported block size
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
        assert backend.name != "FLASH_ATTN"

        # Unsupported sliding window
        backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
        assert backend.name != "FLASH_ATTN"

        # flash-attn is not installed: a None entry in sys.modules makes
        # `import vllm_flash_attn` raise ImportError.
        with patch.dict('sys.modules', {'vllm_flash_attn': None}):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16, None,
                                        16)
            assert backend.name != "FLASH_ATTN"

        # Unsupported head size
        backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
        assert backend.name != "FLASH_ATTN"
def test_invalid_env():
    """Throw an exception if the backend name is invalid."""
    # Use patch.dict rather than a manual backup/restore: assigning the
    # backup back into os.environ raises TypeError when the variable was
    # originally unset (the backup is None). patch.dict removes the key
    # again in that case and also cleans up if the check below fails.
    with patch.dict(os.environ, {"VLLM_ATTENTION_BACKEND": "INVALID"}):
        with pytest.raises(ValueError):
            which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
vllm/attention/backends/flashinfer.py
View file @
ee3eea0a
...
@@ -218,6 +218,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -218,6 +218,7 @@ class FlashInferImpl(AttentionImpl):
)
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
assert
prefill_meta
.
block_tables
is
not
None
assert
prefill_meta
.
block_tables
is
not
None
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
if
kv_cache
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
output
=
flash_attn_varlen_func
(
output
=
flash_attn_varlen_func
(
...
...
vllm/attention/selector.py
View file @
ee3eea0a
...
@@ -30,24 +30,16 @@ def get_attn_backend(
...
@@ -30,24 +30,16 @@ def get_attn_backend(
kv_cache_dtype
:
Optional
[
str
],
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
block_size
:
int
,
)
->
Type
[
AttentionBackend
]:
)
->
Type
[
AttentionBackend
]:
backend
=
_which_attn_to_use
(
num_heads
,
head_size
,
num_kv_heads
,
"""Determine which attention backend to use and only import
sliding_window
,
dtype
,
kv_cache_dtype
,
the selected backend module.
block_size
)
"""
backend
=
which_attn_to_use
(
num_heads
,
head_size
,
num_kv_heads
,
sliding_window
,
dtype
,
kv_cache_dtype
,
block_size
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
)
return
FlashAttentionBackend
# We check it here not in _which_attn_to_use because we cannot know
# the head size until we import FlashAttentionBackend.
supported_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
in
supported_head_sizes
:
logger
.
info
(
"Using FlashAttention-2 backend."
)
return
FlashAttentionBackend
logger
.
info
(
"Cannot use FlashAttention-2 backend for head size %d. "
"Using XFormers backend instead."
,
head_size
)
backend
=
_Backend
.
XFORMERS
if
backend
==
_Backend
.
XFORMERS
:
if
backend
==
_Backend
.
XFORMERS
:
logger
.
info
(
"Using XFormers backend."
)
logger
.
info
(
"Using XFormers backend."
)
from
vllm.attention.backends.xformers
import
(
# noqa: F401
from
vllm.attention.backends.xformers
import
(
# noqa: F401
...
@@ -64,14 +56,15 @@ def get_attn_backend(
...
@@ -64,14 +56,15 @@ def get_attn_backend(
return
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend."
)
logger
.
warning
(
"Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
return
FlashInferBackend
else
:
else
:
raise
ValueError
(
"Invalid attention backend."
)
raise
ValueError
(
"Invalid attention backend."
)
def
_
which_attn_to_use
(
def
which_attn_to_use
(
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
...
@@ -81,54 +74,84 @@ def _which_attn_to_use(
...
@@ -81,54 +74,84 @@ def _which_attn_to_use(
block_size
:
int
,
block_size
:
int
,
)
->
_Backend
:
)
->
_Backend
:
"""Returns which flash attention backend to use."""
"""Returns which flash attention backend to use."""
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
# Check the environment variable and override if specified
backend_by_env_var
:
Optional
[
str
]
=
envs
.
VLLM_ATTENTION_BACKEND
if
backend_by_env_var
is
not
None
:
backend_members
=
_Backend
.
__members__
if
backend_by_env_var
not
in
backend_members
:
raise
ValueError
(
f
"Invalid attention backend '
{
backend_by_env_var
}
'. "
f
"Available backends:
{
', '
.
join
(
backend_members
)
}
"
"(case-sensitive)."
)
selected_backend
=
_Backend
[
backend_by_env_var
]
if
is_cpu
():
if
is_cpu
():
if
selected_backend
!=
_Backend
.
TORCH_SDPA
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
return
_Backend
.
TORCH_SDPA
if
is_hip
():
if
is_hip
():
# AMD GPUs.
# AMD GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
# not Instinct series GPUs.
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
logger
.
info
(
"flash_atten is not supported on NAVI GPUs."
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
torch
.
cuda
.
get_device_capability
()[
0
]
!=
9
:
# not Instinct series GPUs.
logger
.
info
(
"flash_attn is not supported on NAVI GPUs."
)
else
:
logger
.
info
(
"%s is not supported in AMD GPUs."
,
selected_backend
)
return
_Backend
.
ROCM_FLASH
return
_Backend
.
ROCM_FLASH
# NVIDIA GPUs.
# FlashAttn in NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
# Volta and Turing NVIDIA GPUs.
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
logger
.
info
(
"Cannot use FlashAttention-2 backend for Volta and Turing "
# Volta and Turing NVIDIA GPUs.
"GPUs."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs."
)
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
selected_backend
=
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for dtype other than "
elif
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
"torch.float16 or torch.bfloat16."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16."
)
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
selected_backend
=
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
elif
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
return
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
if
block_size
%
16
!=
0
:
selected_backend
=
_Backend
.
XFORMERS
logger
.
info
(
"Cannot use FlashAttention-2 backend for block size not "
elif
block_size
%
16
!=
0
:
"divisible by 16."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16."
)
if
sliding_window
is
not
None
:
selected_backend
=
_Backend
.
XFORMERS
logger
.
info
(
elif
sliding_window
is
not
None
:
"Cannot use FlashAttention-2 backend due to sliding window."
)
logger
.
info
(
return
_Backend
.
XFORMERS
"Cannot use FlashAttention-2 backend due to sliding window."
)
selected_backend
=
_Backend
.
XFORMERS
try
:
import
vllm_flash_attn
# noqa: F401
# FlashAttn is valid for the model, checking if the package is installed.
except
ImportError
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info
(
try
:
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
import
vllm_flash_attn
# noqa: F401
"package is not found. `pip install vllm-flash-attn` for better "
"performance."
)
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
return
_Backend
.
XFORMERS
FlashAttentionBackend
)
backend_by_env_var
=
envs
.
VLLM_ATTENTION_BACKEND
supported_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
backend_by_env_var
is
not
None
:
if
head_size
not
in
supported_sizes
:
return
_Backend
[
backend_by_env_var
]
logger
.
info
(
"Cannot use FlashAttention-2 backend for head size %d."
,
# Default case.
head_size
)
return
_Backend
.
FLASH_ATTN
selected_backend
=
_Backend
.
XFORMERS
except
ImportError
:
logger
.
info
(
"Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance."
)
selected_backend
=
_Backend
.
XFORMERS
return
selected_backend
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment