Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89579a20
Unverified
Commit
89579a20
authored
May 08, 2024
by
Woosuk Kwon
Committed by
GitHub
May 08, 2024
Browse files
[Misc] Use vllm-flash-attn instead of flash-attn (#4686)
parent
230c4b38
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
16 additions
and
31 deletions
+16
-31
Dockerfile
Dockerfile
+0
-21
requirements-cuda.txt
requirements-cuda.txt
+1
-0
setup.py
setup.py
+9
-5
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+1
-1
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-1
vllm/attention/selector.py
vllm/attention/selector.py
+4
-3
No files found.
Dockerfile
View file @
89579a20
...
@@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
...
@@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl
*
pip cache remove vllm_nccl
*
#################### EXTENSION Build IMAGE ####################
#################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
FROM
dev as flash-attn-builder
# max jobs used for build
ARG
max_jobs=2
ENV
MAX_JOBS=${max_jobs}
# flash attention version
ARG
flash_attn_version=v2.5.8
ENV
FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR
/usr/src/flash-attention-v2
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN
pip
--verbose
wheel flash-attn
==
${
FLASH_ATTN_VERSION
}
\
--no-build-isolation
--no-deps
--no-cache-dir
#################### FLASH_ATTENTION Build IMAGE ####################
#################### vLLM installation IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
# image with vLLM installed
FROM
nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
FROM
nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
...
@@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
...
@@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/
RUN
--mount
=
type
=
bind
,from
=
build,src
=
/workspace/dist,target
=
/vllm-workspace/dist
\
RUN
--mount
=
type
=
bind
,from
=
build,src
=
/workspace/dist,target
=
/vllm-workspace/dist
\
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
pip
install
dist/
*
.whl
--verbose
pip
install
dist/
*
.whl
--verbose
RUN
--mount
=
type
=
bind
,from
=
flash-attn-builder,src
=
/usr/src/flash-attention-v2,target
=
/usr/src/flash-attention-v2
\
--mount
=
type
=
cache,target
=
/root/.cache/pip
\
pip
install
/usr/src/flash-attention-v2/
*
.whl
--no-cache-dir
#################### vLLM installation IMAGE ####################
#################### vLLM installation IMAGE ####################
...
...
requirements-cuda.txt
View file @
89579a20
...
@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
...
@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.3.0
torch == 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
setup.py
View file @
89579a20
...
@@ -355,13 +355,17 @@ def get_requirements() -> List[str]:
...
@@ -355,13 +355,17 @@ def get_requirements() -> List[str]:
if
_is_cuda
():
if
_is_cuda
():
requirements
=
_read_requirements
(
"requirements-cuda.txt"
)
requirements
=
_read_requirements
(
"requirements-cuda.txt"
)
cuda_major
=
torch
.
version
.
cuda
.
split
(
"."
)
[
0
]
cuda_major
,
cuda_minor
=
torch
.
version
.
cuda
.
split
(
"."
)
modified_requirements
=
[]
modified_requirements
=
[]
for
req
in
requirements
:
for
req
in
requirements
:
if
"vllm-nccl-cu12"
in
req
:
if
"vllm-nccl-cu12"
in
req
:
modified_requirements
.
append
(
req
=
req
.
replace
(
"vllm-nccl-cu12"
,
req
.
replace
(
"vllm-nccl-cu12"
,
f
"vllm-nccl-cu
{
cuda_major
}
"
))
f
"vllm-nccl-cu
{
cuda_major
}
"
)
else
:
elif
(
"vllm-flash-attn"
in
req
and
not
(
cuda_major
==
"12"
and
cuda_minor
==
"1"
)):
# vllm-flash-attn is built only for CUDA 12.1.
# Skip for other versions.
continue
modified_requirements
.
append
(
req
)
modified_requirements
.
append
(
req
)
requirements
=
modified_requirements
requirements
=
modified_requirements
elif
_is_hip
():
elif
_is_hip
():
...
...
vllm/attention/backends/flash_attn.py
View file @
89579a20
...
@@ -8,7 +8,7 @@ from dataclasses import dataclass
...
@@ -8,7 +8,7 @@ from dataclasses import dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
vllm_
flash_attn
import
flash_attn_varlen_func
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
...
...
vllm/attention/backends/flashinfer.py
View file @
89579a20
...
@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type
...
@@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type
import
flashinfer
import
flashinfer
import
torch
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
...
vllm/attention/selector.py
View file @
89579a20
...
@@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
...
@@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
try
:
try
:
import
flash_attn
# noqa: F401
import
vllm_
flash_attn
# noqa: F401
except
ImportError
:
except
ImportError
:
logger
.
info
(
logger
.
info
(
"Cannot use FlashAttention-2 backend because the flash_attn "
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
"package is not found. Please install it for better performance."
)
"package is not found. `pip install vllm-flash-attn` for better "
"performance."
)
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
backend_by_env_var
=
envs
.
VLLM_ATTENTION_BACKEND
backend_by_env_var
=
envs
.
VLLM_ATTENTION_BACKEND
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment