Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
add4f0bc
Commit
add4f0bc
authored
May 30, 2023
by
Pierce Freeman
Browse files
Scaffolding for wheel prototype
parent
85b51d61
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
105 additions
and
18 deletions
+105
-18
.github/workflows/publish.yml
.github/workflows/publish.yml
+54
-17
setup.py
setup.py
+51
-1
No files found.
.github/workflows/publish.yml
View file @
add4f0bc
...
...
@@ -10,7 +10,7 @@ on:
-
'
**'
jobs
:
release
:
setup_
release
:
name
:
Create Release
runs-on
:
ubuntu-latest
steps
:
...
...
@@ -28,22 +28,26 @@ jobs:
tag_name
:
${{ steps.extract_branch.outputs.branch }}
release_name
:
${{ steps.extract_branch.outputs.branch }}
wheel
:
build_
wheel
s
:
name
:
Build Wheel
runs-on
:
${{ matrix.os }}
needs
:
release
needs
:
setup_
release
strategy
:
fail-fast
:
false
matrix
:
# os: [ubuntu-20.04]
os
:
[
ubuntu-18.04
]
python-version
:
[
'
3.7'
,
'
3.8'
,
'
3.9'
,
'
3.10'
]
torch-version
:
[
1.11.0
,
1.12.0
,
1.12.1
]
cuda-version
:
[
'
113'
,
'
116'
]
exclude
:
-
torch-version
:
1.11.0
cuda-version
:
'
116'
# TODO: @pierce - again, simplify for prototyping
os
:
[
ubuntu-20.04
]
#os: [ubuntu-20.04, ubuntu-22.04]
# python-version: ['3.7', '3.8', '3.9', '3.10']
python-version
:
[
'
3.10'
]
#torch-version: [1.11.0, 1.12.0, 1.12.1]
torch-version
:
[
1.12.1
]
#cuda-version: ['113', '116']
cuda-version
:
[
'
113'
]
#exclude:
# - torch-version: 1.11.0
# cuda-version: '116'
steps
:
-
name
:
Checkout
...
...
@@ -108,11 +112,11 @@ jobs:
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
pip install wheel
pip install
ninja packaging setuptools
wheel
python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} ${wheel_name}
ls dist/*whl |xargs -I {} mv {}
dist/
${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
-
name
:
Upload Release Asset
...
...
@@ -125,3 +129,36 @@ jobs:
asset_path
:
./${{env.wheel_name}}
asset_name
:
${{env.wheel_name}}
asset_content_type
:
application/*
publish_package
:
name
:
Publish package
needs
:
[
build_wheels
]
runs-on
:
ubuntu-latest
steps
:
-
uses
:
actions/checkout@v3
-
uses
:
actions/setup-python@v4
with
:
python-version
:
'
3.10'
-
name
:
List contents
run
:
|
ls -la dist
ls -la dist/*
-
name
:
Install dependencies
run
:
|
pip install ninja packaging setuptools wheel twine
-
name
:
Build core package
run
:
|
python setup.py sdist --dist-dir=dist
-
name
:
Deploy
env
:
TWINE_USERNAME
:
${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD
:
${{ secrets.PYPI_TOKEN }}
run
:
|
python -m twine upload dist/*
setup.py
View file @
add4f0bc
...
...
@@ -10,6 +10,7 @@ from packaging.version import parse, Version
from
setuptools
import
setup
,
find_packages
import
subprocess
import
urllib
import
torch
from
torch.utils.cpp_extension
import
BuildExtension
,
CppExtension
,
CUDAExtension
,
CUDA_HOME
...
...
@@ -22,6 +23,50 @@ with open("README.md", "r", encoding="utf-8") as fh:
this_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
def
get_platform
():
"""
Returns the platform string.
"""
if
sys
.
platform
.
startswith
(
'linux'
):
return
'linux_x86_64'
elif
sys
.
platform
==
'darwin'
:
return
'macosx_10_9_x86_64'
elif
sys
.
platform
==
'win32'
:
return
'win_amd64'
else
:
raise
ValueError
(
'Unsupported platform: {}'
.
format
(
sys
.
platform
))
from
setuptools.command.install
import
install
# @pierce - TODO: Remove for proper release
BASE_WHEEL_URL
=
"https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
class
CustomInstallCommand
(
install
):
def
run
(
self
):
# Determine the version numbers that will be used to determine the correct wheel
_
,
cuda_version
=
get_cuda_bare_metal_version
()
torch_version
=
torch
.
__version__
python_version
=
f
"cp
{
sys
.
version_info
.
major
}{
sys
.
version_info
.
minor
}
"
platform_name
=
get_platform
()
flash_version
=
get_package_version
()
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename
=
f
'flash_attn-
{
flash_version
}
+cu
{
cuda_version
}
torch
{
torch_version
}
-
{
python_version
}
-
{
python_version
}
-
{
platform_name
}
.whl'
wheel_url
=
BASE_WHEEL_URL
.
format
(
tag_name
=
f
"v
{
flash_version
}
"
,
wheel_name
=
wheel_filename
)
try
:
urllib
.
request
.
urlretrieve
(
wheel_url
,
wheel_filename
)
os
.
system
(
f
'pip install
{
wheel_filename
}
'
)
os
.
remove
(
wheel_filename
)
except
urllib
.
error
.
HTTPError
:
print
(
"Precompiled wheel not found. Building from source..."
)
# If the wheel could not be downloaded, build from source
install
.
run
(
self
)
def
get_cuda_bare_metal_version
(
cuda_dir
):
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
output
=
raw_output
.
split
()
...
...
@@ -190,7 +235,12 @@ setup(
"Operating System :: Unix"
,
],
ext_modules
=
ext_modules
,
cmdclass
=
{
"build_ext"
:
BuildExtension
}
if
ext_modules
else
{},
cmdclass
=
{
'install'
:
CustomInstallCommand
,
"build_ext"
:
BuildExtension
}
if
ext_modules
else
{
'install'
:
CustomInstallCommand
,
},
python_requires
=
">=3.7"
,
install_requires
=
[
"torch"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment