gaoqiong / flash-attention
Commit ea2ed886, authored Jun 02, 2023 by Pierce Freeman
Refactor and clean of setup.py
Parent: 9fc9820a

Showing 2 changed files with 133 additions and 115 deletions:
  .github/workflows/publish.yml   +2    -0
  setup.py                        +131  -115
.github/workflows/publish.yml

@@ -150,6 +150,8 @@ jobs:
           pip install ninja packaging setuptools wheel twine

       - name: Build core package
+        env:
+          FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
         run: |
           python setup.py sdist --dist-dir=dist
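The added env block lets this CI step run on a machine with no CUDA toolchain: setup.py (diff below) reads FLASH_ATTENTION_SKIP_CUDA_BUILD and skips compiling the extension. A minimal sketch of reproducing the step locally, assuming the repository root as the working directory:

# Hypothetical local reproduction of the "Build core package" step above.
# The environment variable is the one the workflow now exports; setup.py
# uses it to skip the CUDA extension build during sdist.
import os
import subprocess

env = dict(os.environ, FLASH_ATTENTION_SKIP_CUDA_BUILD="TRUE")
subprocess.run(["python", "setup.py", "sdist", "--dist-dir=dist"], env=env, check=True)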
setup.py

@@ -6,8 +6,10 @@ import re
 import ast
 from pathlib import Path
 from packaging.version import parse, Version
+import platform
 from setuptools import setup, find_packages
+from setuptools.command.install import install
 import subprocess
 import urllib.request
@@ -24,60 +26,29 @@ with open("README.md", "r", encoding="utf-8") as fh:
 this_dir = os.path.dirname(os.path.abspath(__file__))

+# @pierce - TODO: Update for proper release
+BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+
 def get_platform():
     """
-    Returns the platform string.
+    Returns the platform name as used in wheel filenames.
     """
     if sys.platform.startswith('linux'):
         return 'linux_x86_64'
     elif sys.platform == 'darwin':
-        return 'macosx_10_9_x86_64'
+        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
+        return f'macosx_{mac_version}_x86_64'
     elif sys.platform == 'win32':
         return 'win_amd64'
     else:
         raise ValueError('Unsupported platform: {}'.format(sys.platform))

-from setuptools.command.install import install
-
-# @pierce - TODO: Remove for proper release
-BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
-
-class CustomInstallCommand(install):
-    def run(self):
-        if os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE":
-            return install.run(self)
-
-        raise_if_cuda_home_none("flash_attn")
-
-        # Determine the version numbers that will be used to determine the correct wheel
-        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_version_raw = parse(torch.__version__)
-        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-        platform_name = get_platform()
-        flash_version = get_package_version()
-        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
-        wheel_url = BASE_WHEEL_URL.format(
-            #tag_name=f"v{flash_version}",
-            # HACK
-            tag_name=f"v0.0.5",
-            wheel_name=wheel_filename
-        )
-        print("Guessing wheel URL: ", wheel_url)
-
-        try:
-            urllib.request.urlretrieve(wheel_url, wheel_filename)
-            os.system(f'pip install {wheel_filename}')
-            os.remove(wheel_filename)
-        except urllib.error.HTTPError:
-            print("Precompiled wheel not found. Building from source...")
-            # If the wheel could not be downloaded, build from source
-            #install.run(self)
-            raise ValueError
-
 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
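The darwin branch of get_platform() now derives the macOS release from the host instead of hard-coding 10_9. A quick standalone sketch of what the new branch returns; the fallback value is a made-up example so the snippet also runs on non-macOS hosts:

# Illustration of the new darwin branch above (not part of the commit).
import platform

mac_release = platform.mac_ver()[0] or "13.4"  # empty string on non-macOS hosts, example fallback
mac_version = '.'.join(mac_release.split('.')[:2])
print(f'macosx_{mac_version}_x86_64')  # e.g. 'macosx_13.4_x86_64' on a macOS 13.4 machine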
@@ -147,77 +118,77 @@ if not torch.cuda.is_available():
 else:
     os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"

 cmdclass = {}
 ext_modules = []

+if not SKIP_CUDA_BUILD:
+    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+    # See https://github.com/pytorch/pytorch/pull/70650
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    raise_if_cuda_home_none("flash_attn")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_75,code=sm_75")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
+
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_attn_cuda",
+            sources=[
+                "csrc/flash_attn/fmha_api.cpp",
+                "csrc/flash_attn/src/fmha_fwd_hdim32.cu",
+                "csrc/flash_attn/src/fmha_fwd_hdim64.cu",
+                "csrc/flash_attn/src/fmha_fwd_hdim128.cu",
+                "csrc/flash_attn/src/fmha_bwd_hdim32.cu",
+                "csrc/flash_attn/src/fmha_bwd_hdim64.cu",
+                "csrc/flash_attn/src/fmha_bwd_hdim128.cu",
+                "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
+                "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17"] + generator_flag,
+                "nvcc": append_nvcc_threads(
+                    [
+                        "-O3",
+                        "-std=c++17",
+                        "-U__CUDA_NO_HALF_OPERATORS__",
+                        "-U__CUDA_NO_HALF_CONVERSIONS__",
+                        "-U__CUDA_NO_HALF2_OPERATORS__",
+                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                        "--expt-relaxed-constexpr",
+                        "--expt-extended-lambda",
+                        "--use_fast_math",
+                        "--ptxas-options=-v",
+                        "-lineinfo"
+                    ]
+                    + generator_flag
+                    + cc_flag
+                ),
+            },
+            include_dirs=[
+                Path(this_dir) / 'csrc' / 'flash_attn',
+                Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
+                Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
+            ],
+        )
+    )

 def get_package_version():
     with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:

(The removed side of this hunk is the same code at module level: the print/TORCH_MAJOR/TORCH_MINOR lines sat directly below the TORCH_CUDA_ARCH_LIST assignment, and the generator_flag, cc_flag, and CUDAExtension block followed unindented. The change moves all of it under the new `if not SKIP_CUDA_BUILD:` guard.)
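For reference, the -gencode selection in the hunk above resolves as follows; the CUDA version here is a made-up example, whereas setup.py derives it from `nvcc -V` via get_cuda_bare_metal_version(CUDA_HOME):

# Standalone sketch of the compute-capability flag selection (not part of the commit).
from packaging.version import Version

bare_metal_version = Version("11.8")  # example value only

cc_flag = ["-gencode", "arch=compute_75,code=sm_75",
           "-gencode", "arch=compute_80,code=sm_80"]
if bare_metal_version >= Version("11.8"):
    cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]
print(cc_flag)
# ['-gencode', 'arch=compute_75,code=sm_75', '-gencode', 'arch=compute_80,code=sm_80',
#  '-gencode', 'arch=compute_90,code=sm_90']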
@@ -229,18 +200,63 @@ def get_package_version():
     else:
         return str(public_version)

+
+class CachedWheelsCommand(install):
+    """
+    Installer hook to scan for existing wheels that match the current platform environment.
+    Falls back to building from source if no wheel is found.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return install.run(self)
+
+        raise_if_cuda_home_none("flash_attn")
+
+        # Determine the version numbers that will be used to determine the correct wheel
+        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_version_raw = parse(torch.__version__)
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
+
+        # Determine wheel URL based on CUDA version, torch version, python version and OS
+        wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
+        wheel_url = BASE_WHEEL_URL.format(
+            tag_name=f"v{flash_version}",
+            wheel_name=wheel_filename
+        )
+        print("Guessing wheel URL: ", wheel_url)
+
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+            os.system(f'pip install {wheel_filename}')
+            os.remove(wheel_filename)
+        except urllib.error.HTTPError:
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            install.run(self)
+
+
 setup(
-    name="flash_attn",
+    # @pierce - TODO: Revert for official release
+    name="flash_attn_wheels",
     version=get_package_version(),
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),
-    author="Tri Dao",
-    author_email="trid@stanford.edu",
+    #author="Tri Dao",
+    #author_email="trid@stanford.edu",
+    # @pierce - TODO: Revert for official release
+    author="Pierce Freeman",
+    author_email="pierce@freeman.vc",
     description="Flash Attention: Fast and Memory-Efficient Exact Attention",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    url="https://github.com/HazyResearch/flash-attention",
+    #url="https://github.com/HazyResearch/flash-attention",
+    url="https://github.com/piercefreeman/flash-attention",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: BSD License",
@@ -248,10 +264,10 @@ setup(
     ],
     ext_modules=ext_modules,
     cmdclass={
-        'install': CustomInstallCommand,
+        'install': CachedWheelsCommand,
         "build_ext": BuildExtension
     } if ext_modules else {
-        'install': CustomInstallCommand,
+        'install': CachedWheelsCommand,
     },
     python_requires=">=3.7",
     install_requires=[
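The install hook is registered in both branches of the conditional, so `pip install` always routes through CachedWheelsCommand, while BuildExtension is only added when ext_modules is non-empty, that is, when the CUDA extension is actually being compiled rather than skipped via FLASH_ATTENTION_SKIP_CUDA_BUILD. Setting FLASH_ATTENTION_FORCE_BUILD=TRUE bypasses the wheel lookup entirely and falls through to a normal source build.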