gaoqiong / flash-attention · Commits

Commit 498cd8c3
Authored Mar 28, 2024 by Woosuk Kwon
Parent: ae856f3a

    flash-attn -> vllm-flash-attn

Showing 3 changed files with 16 additions and 70 deletions (+16, -70):
    setup.py                                   +14  -68
    vllm_flash_attn/__init__.py                 +1   -1
    vllm_flash_attn/flash_attn_interface.py     +1   -1
setup.py  (view file @ 498cd8c3)

@@ -32,7 +32,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))

-PACKAGE_NAME = "flash_attn"
+PACKAGE_NAME = "vllm_flash_attn"

 BASE_WHEEL_URL = (
     "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
@@ -106,7 +106,7 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]

-    check_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none(PACKAGE_NAME)
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
@@ -132,7 +132,7 @@ if not SKIP_CUDA_BUILD:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_2_cuda",
+            name="vllm_flash_attn_2_cuda",
             sources=[
                 "csrc/flash_attn/flash_api.cpp",
                 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
@@ -215,7 +215,7 @@ if not SKIP_CUDA_BUILD:


 def get_package_version():
-    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
     public_version = ast.literal_eval(version_match.group(1))
     local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
@@ -225,29 +225,6 @@ def get_package_version():
     return str(public_version)


-def get_wheel_url():
-    # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
-    torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    flash_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-    return wheel_url, wheel_filename
-
-
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
@@ -260,28 +237,6 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()

-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()
-

 class NinjaBuildExtension(BuildExtension):
     def __init__(self, *args, **kwargs) -> None:
@@ -304,7 +259,7 @@ class NinjaBuildExtension(BuildExtension):


 setup(
-    name=PACKAGE_NAME,
+    name="vllm-flash-attn",
     version=get_package_version(),
     packages=find_packages(
         exclude=(
@@ -315,15 +270,13 @@ setup(
             "dist",
             "docs",
             "benchmarks",
-            "flash_attn.egg-info",
+            f"{PACKAGE_NAME}.egg-info",
         )
     ),
-    author="Tri Dao",
-    author_email="trid@cs.stanford.edu",
-    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/flash-attention",
+    author="vLLM Team",
+    description="Forward-only flash-attn",
+    long_description="Forward-only flash-attn package built for PyTorch 2.1.2 and CUDA 12.1",
+    url="https://github.com/vllm-project/flash-attention.git",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -335,14 +288,7 @@ setup(
     else {
         "bdist_wheel": CachedWheelsCommand,
     },
-    python_requires=">=3.7",
-    install_requires=[
-        "torch",
-        "einops",
-        "packaging",
-        "ninja",
-    ],
-    setup_requires=[
-        "psutil"
-    ],
-)
\ No newline at end of file
+    python_requires=">=3.8",
+    install_requires=["torch == 2.1.2"],
+    setup_requires=["psutil"],
+)
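
For reference, this commit drops the prebuilt-wheel download path (get_wheel_url and the download logic in CachedWheelsCommand.run) and pins the build to torch == 2.1.2. The deleted helper composed a wheel name from the local torch, CUDA, and Python environment; the sketch below reuses the template from the removed code, but the concrete values are illustrative assumptions (the version "2.5.6" comes from vllm_flash_attn/__init__.py below, the rest are hypothetical):

    # Sketch of the wheel-name scheme used by the deleted get_wheel_url().
    # Values are illustrative; a real run derived them from torch, sys, and the OS.
    BASE_WHEEL_URL = (
        "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
    )

    package_name = "flash_attn"     # pre-rename name used by the upstream release wheels
    flash_version = "2.5.6"         # package __version__
    cuda_version = "122"            # CUDA 12.x builds were mapped to 12.2
    torch_version = "2.1"           # torch major.minor
    cxx11_abi = "FALSE"             # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
    python_version = "cp310"        # f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = "linux_x86_64"  # get_platform()

    wheel_filename = (
        f"{package_name}-{flash_version}+cu{cuda_version}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    )
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    print(wheel_url)

With this path removed, setup.py no longer attempts to fetch a prebuilt wheel from the upstream Dao-AILab releases page.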
vllm_flash_attn/__init__.py  (view file @ 498cd8c3)

 __version__ = "2.5.6"

-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
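
After the rename, downstream code imports the kernels through the vllm_flash_attn package rather than flash_attn. A minimal usage sketch, assuming the package built from this commit is installed and a CUDA GPU with fp16 support is available (the tensor shapes are illustrative):

    import torch
    from vllm_flash_attn import flash_attn_func  # re-exported by this __init__.py

    # Hypothetical shapes: (batch, seqlen, num_heads, head_dim), fp16 on GPU
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

    out = flash_attn_func(q, k, v, causal=True)  # output has the same shape as q
    print(out.shape)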
vllm_flash_attn/flash_attn_interface.py  (view file @ 498cd8c3)

@@ -7,7 +7,7 @@ import torch.nn as nn

 # isort: off
 # We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda

 # isort: on
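
The module imported here has to match the CUDAExtension name declared in setup.py, which this commit also changes to vllm_flash_attn_2_cuda. A small sanity-check sketch, assuming the package has been built and installed from this commit:

    import importlib

    # Both the Python wrapper and the compiled CUDA extension should resolve
    # under their vllm_-prefixed names; the upstream flash_attn names are untouched.
    wrapper = importlib.import_module("vllm_flash_attn.flash_attn_interface")
    kernels = importlib.import_module("vllm_flash_attn_2_cuda")
    print(wrapper.__name__, kernels.__name__)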