gaoqiong / flash-attention

Commit 498cd8c3
Authored Mar 28, 2024 by Woosuk Kwon
Parent: ae856f3a

flash-attn -> vllm-flash-attn
Showing 3 changed files with 16 additions and 70 deletions:

  setup.py                                  +14  -68
  vllm_flash_attn/__init__.py                +1   -1
  vllm_flash_attn/flash_attn_interface.py    +1   -1
setup.py
@@ -32,7 +32,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
-PACKAGE_NAME = "flash_attn"
+PACKAGE_NAME = "vllm_flash_attn"
 
 BASE_WHEEL_URL = (
     "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
@@ -106,7 +106,7 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]
 
-    check_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none(PACKAGE_NAME)
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
     if CUDA_HOME is not None:
@@ -132,7 +132,7 @@ if not SKIP_CUDA_BUILD:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_2_cuda",
+            name="vllm_flash_attn_2_cuda",
             sources=[
                 "csrc/flash_attn/flash_api.cpp",
                 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
@@ -215,7 +215,7 @@ if not SKIP_CUDA_BUILD:
 
 
 def get_package_version():
-    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
+    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
         version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
     public_version = ast.literal_eval(version_match.group(1))
     local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
@@ -225,29 +225,6 @@ def get_package_version():
         return str(public_version)
 
 
-def get_wheel_url():
-    # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
-    torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
-    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-    platform_name = get_platform()
-    flash_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
-    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-    return wheel_url, wheel_filename
-
-
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
@@ -260,28 +237,6 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()
 
-        wheel_url, wheel_filename = get_wheel_url()
-        print("Guessing wheel URL: ", wheel_url)
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-
-            # Make the archive
-            # Lifted from the root wheel processing command
-            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
-            if not os.path.exists(self.dist_dir):
-                os.makedirs(self.dist_dir)
-
-            impl_tag, abi_tag, plat_tag = self.get_tag()
-            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
-
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
-            print("Raw wheel path", wheel_path)
-            os.rename(wheel_filename, wheel_path)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            super().run()
 
 
 class NinjaBuildExtension(BuildExtension):
     def __init__(self, *args, **kwargs) -> None:
@@ -304,7 +259,7 @@ class NinjaBuildExtension(BuildExtension):
 
 
 setup(
-    name=PACKAGE_NAME,
+    name="vllm-flash-attn",
     version=get_package_version(),
     packages=find_packages(
         exclude=(
@@ -315,15 +270,13 @@ setup(
             "dist",
             "docs",
             "benchmarks",
-            "flash_attn.egg-info",
+            f"{PACKAGE_NAME}.egg-info",
         )
     ),
-    author="Tri Dao",
-    author_email="trid@cs.stanford.edu",
-    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/Dao-AILab/flash-attention",
+    author="vLLM Team",
+    description="Forward-only flash-attn",
+    long_description="Forward-only flash-attn package built for PyTorch 2.1.2 and CUDA 12.1",
+    url="https://github.com/vllm-project/flash-attention.git",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -335,14 +288,7 @@ setup(
     else {
         "bdist_wheel": CachedWheelsCommand,
     },
-    python_requires=">=3.7",
-    install_requires=[
-        "torch",
-        "einops",
-        "packaging",
-        "ninja",
-    ],
-    setup_requires=[
-        "psutil"
-    ],
+    python_requires=">=3.8",
+    install_requires=["torch == 2.1.2"],
+    setup_requires=["psutil"],
 )
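The largest removal above is get_wheel_url() together with the wheel-download path in CachedWheelsCommand.run(), so this fork always compiles from source instead of trying to fetch a prebuilt wheel from the upstream releases page. For reference, here is a minimal sketch of the wheel-naming scheme implemented by the removed helper; every concrete value below is an illustrative assumption, not something taken from this commit:

# Sketch of the wheel naming scheme from the removed get_wheel_url();
# the values below are illustrative assumptions, not from this commit.
PACKAGE_NAME = "flash_attn"
flash_version = "2.5.6"
cuda_version = "122"            # CUDA version torch was built against (major + minor)
torch_version = "2.1"
cxx11_abi = "FALSE"             # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
python_version = "cp310"        # f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}"
    f"torch{torch_version}cxx11abi{cxx11_abi}"
    f"-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# flash_attn-2.5.6+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl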
vllm_flash_attn/__init__.py
 __version__ = "2.5.6"
 
-from flash_attn.flash_attn_interface import (
+from vllm_flash_attn.flash_attn_interface import (
     flash_attn_func,
     flash_attn_kvpacked_func,
     flash_attn_qkvpacked_func,
...
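After this rename, downstream code imports the public API from vllm_flash_attn rather than flash_attn. A hypothetical usage sketch (not part of this commit); it assumes an installed build, a CUDA device, and illustrative tensor shapes:

# Hypothetical usage sketch: import from the renamed package.
# Requires a CUDA device; shapes are (batch, seqlen, nheads, headdim).
import torch
from vllm_flash_attn import flash_attn_func

q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

# Same call signature as upstream flash_attn's flash_attn_func.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])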
vllm_flash_attn/flash_attn_interface.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 # isort: off
 # We need to import the CUDA kernels after importing torch
-import flash_attn_2_cuda as flash_attn_cuda
+import vllm_flash_attn_2_cuda as flash_attn_cuda
 # isort: on
...
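Because the compiled extension is now registered as vllm_flash_attn_2_cuda, it no longer shares a module name with the flash_attn_2_cuda binary shipped by upstream flash-attn, so in principle both packages can live in the same environment. A small post-install sanity check, offered as an assumption rather than anything shipped in this commit:

# Hypothetical sanity check: confirm which compiled attention extensions are
# importable after installation; both module names appear in the diffs above.
import importlib.util

for mod in ("vllm_flash_attn_2_cuda", "flash_attn_2_cuda"):
    spec = importlib.util.find_spec(mod)
    print(f"{mod}: {'found' if spec is not None else 'not installed'}")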