Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1cb0cc29
Unverified
Commit
1cb0cc29
authored
Mar 08, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 08, 2024
Browse files
[FIX] Make `flash_attn` optional (#3269)
parent
99c3cfb8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
78 deletions
+41
-78
.gitignore
.gitignore
+0
-3
setup.py
setup.py
+3
-45
vllm/__init__.py
vllm/__init__.py
+7
-23
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+31
-6
vllm/model_executor/layers/attention/backends/flash_attn.py
vllm/model_executor/layers/attention/backends/flash_attn.py
+0
-1
No files found.
.gitignore
View file @
1cb0cc29
...
@@ -184,6 +184,3 @@ _build/
...
@@ -184,6 +184,3 @@ _build/
# Benchmark dataset
# Benchmark dataset
*.json
*.json
# Third-party Python packages.
vllm/thirdparty_files/
setup.py
View file @
1cb0cc29
...
@@ -3,7 +3,6 @@ import io
...
@@ -3,7 +3,6 @@ import io
import
os
import
os
import
re
import
re
import
subprocess
import
subprocess
import
sys
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
,
Set
from
typing
import
List
,
Set
...
@@ -15,8 +14,6 @@ import torch.utils.cpp_extension as torch_cpp_ext
...
@@ -15,8 +14,6 @@ import torch.utils.cpp_extension as torch_cpp_ext
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
from
torch.utils.cpp_extension
import
BuildExtension
,
CUDAExtension
,
CUDA_HOME
,
ROCM_HOME
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
ROOT_DIR
=
os
.
path
.
dirname
(
__file__
)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR
=
"vllm/thirdparty_files"
# If you are developing the C++ backend of vLLM, consider building vLLM with
# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# `python setup.py develop` since it will give you incremental builds.
...
@@ -327,46 +324,8 @@ if _is_cuda():
...
@@ -327,46 +324,8 @@ if _is_cuda():
"nvcc"
:
NVCC_FLAGS_PUNICA
,
"nvcc"
:
NVCC_FLAGS_PUNICA
,
},
},
))
))
elif
_is_neuron
():
# Download the FlashAttention package.
neuronxcc_version
=
get_neuronxcc_version
()
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version
=
"2.5.6"
install_dir
=
os
.
path
.
join
(
ROOT_DIR
,
THIRDPARTY_SUBDIR
)
subprocess
.
check_call
(
[
sys
.
executable
,
"-m"
,
"pip"
,
"install"
,
"-q"
,
f
"--target=
{
install_dir
}
"
,
"einops"
,
# Dependency of flash-attn.
f
"flash-attn==
{
flash_attn_version
}
"
,
"--no-dependencies"
,
# Required to avoid re-installing torch.
],
env
=
dict
(
os
.
environ
,
CC
=
"gcc"
),
)
# Copy the FlashAttention package into the vLLM package after build.
class
build_ext
(
BuildExtension
):
def
run
(
self
):
super
().
run
()
target_dir
=
os
.
path
.
join
(
self
.
build_lib
,
THIRDPARTY_SUBDIR
)
if
not
os
.
path
.
exists
(
target_dir
):
os
.
makedirs
(
target_dir
)
self
.
copy_tree
(
install_dir
,
target_dir
)
class
BinaryDistribution
(
setuptools
.
Distribution
):
def
has_ext_modules
(
self
):
return
True
else
:
build_ext
=
BuildExtension
BinaryDistribution
=
setuptools
.
Distribution
if
_is_neuron
():
neuronxcc_version
=
get_neuronxcc_version
()
vllm_extension_sources
=
[
vllm_extension_sources
=
[
"csrc/cache_kernels.cu"
,
"csrc/cache_kernels.cu"
,
...
@@ -509,7 +468,6 @@ setuptools.setup(
...
@@ -509,7 +468,6 @@ setuptools.setup(
python_requires
=
">=3.8"
,
python_requires
=
">=3.8"
,
install_requires
=
get_requirements
(),
install_requires
=
get_requirements
(),
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
build_ext
}
if
not
_is_neuron
()
else
{},
cmdclass
=
{
"build_ext"
:
BuildExtension
}
if
not
_is_neuron
()
else
{},
distclass
=
BinaryDistribution
,
package_data
=
package_data
,
package_data
=
package_data
,
)
)
vllm/__init__.py
View file @
1cb0cc29
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
def
_configure_system
():
from
vllm.engine.llm_engine
import
LLMEngine
import
os
from
vllm.engine.ray_utils
import
initialize_cluster
import
sys
from
vllm.entrypoints.llm
import
LLM
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
# Importing flash-attn.
from
vllm.sampling_params
import
SamplingParams
thirdparty_files
=
os
.
path
.
join
(
os
.
path
.
abspath
(
os
.
path
.
dirname
(
__file__
)),
"thirdparty_files"
)
sys
.
path
.
insert
(
0
,
thirdparty_files
)
_configure_system
()
# Delete configuration function.
del
_configure_system
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
# noqa: E402
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# noqa: E402
from
vllm.engine.llm_engine
import
LLMEngine
# noqa: E402
from
vllm.engine.ray_utils
import
initialize_cluster
# noqa: E402
from
vllm.entrypoints.llm
import
LLM
# noqa: E402
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
# noqa: E402
from
vllm.sampling_params
import
SamplingParams
# noqa: E402
__version__
=
"0.3.3"
__version__
=
"0.3.3"
...
...
vllm/model_executor/layers/attention/attention.py
View file @
1cb0cc29
"""Attention layer."""
"""Attention layer."""
from
functools
import
lru_cache
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
logger
=
init_logger
(
__name__
)
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
"""Attention layer.
"""Attention layer.
...
@@ -30,17 +34,12 @@ class Attention(nn.Module):
...
@@ -30,17 +34,12 @@ class Attention(nn.Module):
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
(
not
is_hip
()
and
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
and
if
_use_flash_attn
():
torch
.
get_default_dtype
()
in
(
torch
.
float16
,
torch
.
bfloat16
)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
from
vllm.model_executor.layers.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.model_executor.layers.attention.backends.flash_attn
import
FlashAttentionBackend
self
.
backend
=
FlashAttentionBackend
(
num_heads
,
head_size
,
scale
,
self
.
backend
=
FlashAttentionBackend
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
num_kv_heads
,
alibi_slopes
,
sliding_window
)
sliding_window
)
else
:
else
:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from
vllm.model_executor.layers.attention.backends.xformers
import
XFormersBackend
from
vllm.model_executor.layers.attention.backends.xformers
import
XFormersBackend
self
.
backend
=
XFormersBackend
(
num_heads
,
head_size
,
scale
,
self
.
backend
=
XFormersBackend
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
num_kv_heads
,
alibi_slopes
,
...
@@ -57,3 +56,29 @@ class Attention(nn.Module):
...
@@ -57,3 +56,29 @@ class Attention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
backend
.
forward
(
query
,
key
,
value
,
key_cache
,
value_cache
,
return
self
.
backend
.
forward
(
query
,
key
,
value
,
key_cache
,
value_cache
,
input_metadata
)
input_metadata
)
@
lru_cache
(
maxsize
=
1
)
def
_use_flash_attn
()
->
bool
:
try
:
import
flash_attn
# noqa: F401
except
ImportError
:
logger
.
info
(
"flash_attn is not found. Using xformers backend."
)
return
False
if
is_hip
():
# AMD GPUs.
return
False
if
torch
.
cuda
.
get_device_capability
()[
0
]
<
8
:
# Volta and Turing NVIDIA GPUs.
logger
.
info
(
"flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend."
)
return
False
if
torch
.
get_default_dtype
()
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info
(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend."
)
return
False
logger
.
info
(
"Using flash_attn backend."
)
return
True
vllm/model_executor/layers/attention/backends/flash_attn.py
View file @
1cb0cc29
"""Attention layer with Flash and PagedAttention."""
"""Attention layer with Flash and PagedAttention."""
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from
flash_attn
import
flash_attn_func
from
flash_attn
import
flash_attn_func
import
torch
import
torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment