gaoqiong / flash-attention · Commit d4a7c8ff

Authored Nov 27, 2023 by Tri Dao
Parent: ce3e7280

[CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2, add torch-nightly
Showing 3 changed files with 21 additions and 34 deletions:

  .github/workflows/publish.yml        +13  -25
  flash_attn/flash_attn_interface.py    +2   -0
  setup.py                              +6   -9
.github/workflows/publish.yml  (+13 -25)
@@ -44,8 +44,8 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0']
-          cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127']
+          cuda-version: ['11.8.0', '12.2.0']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -58,31 +58,17 @@ jobs:
           # Pytorch >= 2.0 only supports Python >= 3.8
           - torch-version: '2.0.1'
             python-version: '3.7'
-          - torch-version: '2.1.0'
+          - torch-version: '2.1.1'
             python-version: '3.7'
+          - torch-version: '2.2.0.dev20231127'
+            python-version: '3.7'
           # Pytorch <= 2.0 only supports CUDA <= 11.8
-          - torch-version: '1.12.1'
-            cuda-version: '12.1.0'
           - torch-version: '1.12.1'
             cuda-version: '12.2.0'
-          - torch-version: '1.13.1'
-            cuda-version: '12.1.0'
           - torch-version: '1.13.1'
             cuda-version: '12.2.0'
-          - torch-version: '2.0.1'
-            cuda-version: '12.1.0'
           - torch-version: '2.0.1'
             cuda-version: '12.2.0'
-          # Pytorch >= 2.1 only supports CUDA >= 11.8
-          - torch-version: '2.1.0'
-            cuda-version: '11.6.2'
-          - torch-version: '2.1.0'
-            cuda-version: '11.7.1'
-          # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
-          # we only use CUDA 12.2. setup.py as a special case that will
-          # download the wheel for CUDA 12.2 instead.
-          - torch-version: '2.1.0'
-            cuda-version: '12.1.0'

     steps:
       - name: Checkout
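For scale: a hedged Python sketch that counts the wheel-build jobs implied by the two hunks above, using only the python/torch/cuda axes visible here (the cxx11_abi axis and runner OS are ignored, so the absolute counts are illustrative):

    from itertools import product

    # New matrix values from the first hunk.
    pythons = ['3.7', '3.8', '3.9', '3.10', '3.11']
    torches = ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127']
    cudas = ['11.8.0', '12.2.0']

    # New exclude entries from the second hunk.
    excludes = [
        {'torch-version': '2.0.1', 'python-version': '3.7'},
        {'torch-version': '2.1.1', 'python-version': '3.7'},
        {'torch-version': '2.2.0.dev20231127', 'python-version': '3.7'},
        {'torch-version': '1.12.1', 'cuda-version': '12.2.0'},
        {'torch-version': '1.13.1', 'cuda-version': '12.2.0'},
        {'torch-version': '2.0.1', 'cuda-version': '12.2.0'},
    ]

    def is_excluded(py, tv, cv):
        combo = {'python-version': py, 'torch-version': tv, 'cuda-version': cv}
        return any(all(combo[k] == v for k, v in ex.items()) for ex in excludes)

    jobs = [c for c in product(pythons, torches, cudas) if not is_excluded(*c)]
    print(len(jobs))  # 30 jobs; the pre-commit matrix and excludes give 50

Dropping the cu116/cu117/cu121 columns roughly halves the number of wheels built per release.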
@@ -107,6 +93,12 @@ jobs:
           sudo rm -rf /opt/ghc
           sudo rm -rf /opt/hostedtoolcache/CodeQL

+      - name: Set up swap space
+        if: runner.os == 'Linux'
+        uses: pierotofy/set-swap-space@v1.0
+        with:
+          swap-size-gb: 10
+
       - name: Install CUDA ${{ matrix.cuda-version }}
         if: ${{ matrix.cuda-version != 'cpu' }}
         uses: Jimver/cuda-toolkit@v0.2.11
@@ -130,7 +122,7 @@ jobs:
           # We want to figure out the CUDA version to download pytorch
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # This code is ugly, maybe there's a better way to do this.
-          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
+          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
             pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
           else
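The export above picks which PyTorch wheel index to install from: it clamps the system CUDA version into the range of cuXXX wheels each PyTorch release actually publishes, now extended with entries for the 2.2 nightly. A standalone sketch of the same logic (the function name is illustrative):

    # MATRIX_TORCH_VERSION is the major.minor torch version, e.g. "2.1";
    # MATRIX_CUDA_VERSION is the CUDA version with dots stripped, e.g. "122".
    MIN_CUDA = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}
    MAX_CUDA = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}

    def torch_wheel_cuda(torch_version: str, system_cuda: int) -> int:
        # Clamp into [min, max] of the wheel CUDA versions published for this torch.
        return max(min(system_cuda, MAX_CUDA[torch_version]), MIN_CUDA[torch_version])

    assert torch_wheel_cuda('2.1', 122) == 121   # CUDA 12.2 system -> cu121 wheel
    assert torch_wheel_cuda('1.12', 118) == 116  # CUDA 11.8 system -> cu116 wheel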
@@ -153,12 +145,8 @@ jobs:
           pip install ninja packaging wheel
           export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-          # Currently for this setting the runner goes OOM if we pass --threads 4 to nvcc
-          if [[ ( ${MATRIX_CUDA_VERSION} == "121" || ${MATRIX_CUDA_VERSION} == "122" ) && ${MATRIX_TORCH_VERSION} == "2.1" ]]; then
-            export FLASH_ATTENTION_FORCE_SINGLE_THREAD="TRUE"
-          fi
           # Limit MAX_JOBS otherwise the github runner goes OOM
-          MAX_JOBS=1 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
           tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
           ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
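The sed expression renames each wheel by replacing the second '-' in its filename, which splices a local build tag (a PEP 440 local version) into the version component. A hedged Python equivalent; the example filename below is illustrative:

    def tag_wheel_name(wheel_name: str, tag: str) -> str:
        # Wheel names look like {dist}-{version}-{python}-{abi}-{platform}.whl,
        # so rewriting the second '-' appends '+<tag>' to {version}.
        dist, version, rest = wheel_name.split('-', 2)
        return f'{dist}-{version}+{tag}-{rest}'

    print(tag_wheel_name('flash_attn-2.3.6-cp310-cp310-linux_x86_64.whl',
                         'cu122torch2.1cxx11abiFALSE'))
    # flash_attn-2.3.6+cu122torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl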
flash_attn/flash_attn_interface.py  (+2 -0)

 # Copyright (c) 2023, Tri Dao.
 from typing import Optional, Union

 import torch
 ...
setup.py  (+6 -9)

 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+# Copyright (c) 2023, Tri Dao.
 import sys
 import warnings
 import os
 ...
@@ -43,8 +44,6 @@ FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
...
@@ -43,8 +44,6 @@ FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD
=
os
.
getenv
(
"FLASH_ATTENTION_SKIP_CUDA_BUILD"
,
"FALSE"
)
==
"TRUE"
SKIP_CUDA_BUILD
=
os
.
getenv
(
"FLASH_ATTENTION_SKIP_CUDA_BUILD"
,
"FALSE"
)
==
"TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_CXX11_ABI"
,
"FALSE"
)
==
"TRUE"
FORCE_CXX11_ABI
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_CXX11_ABI"
,
"FALSE"
)
==
"TRUE"
# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM
FORCE_SINGLE_THREAD
=
os
.
getenv
(
"FLASH_ATTENTION_FORCE_SINGLE_THREAD"
,
"FALSE"
)
==
"TRUE"
def
get_platform
():
def
get_platform
():
@@ -84,9 +83,7 @@ def check_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
-    if not FORCE_SINGLE_THREAD:
-        return nvcc_extra_args + ["--threads", "4"]
-    return nvcc_extra_args
+    return nvcc_extra_args + ["--threads", "4"]


 cmdclass = {}
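With the FORCE_SINGLE_THREAD escape hatch removed, nvcc always receives --threads 4; together with MAX_JOBS=2 in the workflow, up to two compile jobs of four nvcc threads each can run concurrently, which the added 10 GB of swap is there to absorb. A minimal before/after sketch (standalone copies, not imported from setup.py):

    def append_nvcc_threads_old(nvcc_extra_args, force_single_thread=False):
        # Pre-commit behavior: an env-driven flag could suppress multithreading.
        if not force_single_thread:
            return nvcc_extra_args + ['--threads', '4']
        return nvcc_extra_args

    def append_nvcc_threads_new(nvcc_extra_args):
        # Post-commit behavior: always compile with 4 nvcc threads.
        return nvcc_extra_args + ['--threads', '4']

    print(append_nvcc_threads_old(['-O3'], force_single_thread=True))  # ['-O3']
    print(append_nvcc_threads_new(['-O3']))  # ['-O3', '--threads', '4']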
@@ -233,9 +230,9 @@ def get_wheel_url():
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
     torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
-    # Workaround for nvcc 12.1 segfaults when compiling with Pytorch 2.1
-    if torch_version_raw.major == 2 and torch_version_raw.minor == 1 and torch_cuda_version.major == 12:
-        torch_cuda_version = parse("12.2")
+    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+    # to save CI time. Minor versions should be compatible.
+    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     flash_version = get_package_version()
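The new rule means a user's torch.version.cuda no longer has to match a published wheel exactly: any CUDA 11.x install maps to the cu118 wheel and any 12.x install to cu122. A hedged sketch of just that mapping (packaging.version.parse as imported in setup.py; the tag format here is illustrative):

    from packaging.version import parse

    def wheel_cuda_tag(torch_cuda: str) -> str:
        v = parse(torch_cuda)
        # Only 11.8 and 12.2 wheels are published after this commit; minor
        # versions within a major CUDA release are assumed compatible.
        chosen = parse('11.8') if v.major == 11 else parse('12.2')
        return f'cu{chosen.major}{chosen.minor}'

    assert wheel_cuda_tag('11.7') == 'cu118'
    assert wheel_cuda_tag('12.1') == 'cu122'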