gaoqiong / flash-attention / Commits

Commit 844912dc
authored Jul 11, 2024 by Tri Dao

[CI] Switch from CUDA 12.2 to 12.3

parent 40e534a7
Showing 2 changed files with 6 additions and 6 deletions (+6 -6)

.github/workflows/publish.yml  +4 -4
setup.py                       +2 -2
.github/workflows/publish.yml

@@ -45,7 +45,7 @@ jobs:
           os: [ubuntu-20.04]
           python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
           torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
-          cuda-version: ['11.8.0', '12.2.2']
+          cuda-version: ['11.8.0', '12.3.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
           # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
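As a rough illustration of why the workflow pins a single CUDA 12.x patch release rather than building several, the matrix above already multiplies out to many wheel builds; the count below is back-of-the-envelope arithmetic using only the versions visible in this hunk, not part of the workflow, and the cxx11_abi axis mentioned in the comments would double it again.

# Back-of-the-envelope count of the build matrix above (not part of the workflow).
from itertools import product

python_versions = ['3.8', '3.9', '3.10', '3.11', '3.12']
torch_versions = ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514']
cuda_versions = ['11.8.0', '12.3.2']

# 5 * 5 * 2 = 50 jobs per C++11-ABI setting, before the exclude entries trim it.
print(len(list(product(python_versions, torch_versions, cuda_versions))))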
@@ -60,7 +60,7 @@ jobs:
             python-version: '3.12'
           # Pytorch <= 2.0 only supports CUDA <= 11.8
           - torch-version: '2.0.1'
-            cuda-version: '12.2.2'
+            cuda-version: '12.3.2'

     steps:
       - name: Checkout
@@ -145,8 +145,8 @@ jobs:
           export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
           # Limit MAX_JOBS otherwise the github runner goes OOM
-          # CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM
-          MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+          # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM
+          MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
           tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
           ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
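The MAX_JOBS guard compares $MATRIX_CUDA_VERSION against a dotless major+minor string, which is why the commit changes "122" to "123" alongside the matrix entry. A small sketch of that relationship follows; the actual derivation of MATRIX_CUDA_VERSION happens elsewhere in the workflow and is assumed here, not shown in this diff.

# Illustrative sketch: how a dotless "123" relates to the matrix entry '12.3.2'.
# The derivation of MATRIX_CUDA_VERSION is assumed, not copied from this diff.
def bare_cuda_version(cuda_version: str) -> str:
    """Join major and minor, dropping dots: '12.3.2' -> '123'."""
    major, minor = cuda_version.split(".")[:2]
    return major + minor

def max_jobs(cuda_version: str) -> int:
    """Mirror of the shell guard: CUDA 12.3 builds get 1 job to avoid OOM, others get 2."""
    return 1 if bare_cuda_version(cuda_version) == "123" else 2

assert bare_cuda_version("12.3.2") == "123"
assert max_jobs("12.3.2") == 1
assert max_jobs("11.8.0") == 2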
setup.py

@@ -269,9 +269,9 @@ def get_wheel_url():
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
     torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
     # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     flash_version = get_package_version()
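The pinned torch_cuda_version above ends up in the name of the prebuilt wheel that get_wheel_url looks up, so bumping it to 12.3 has to match the tag the publish workflow bakes into the wheel filename (cu${MATRIX_CUDA_VERSION}torch...cxx11abi...). The sketch below approximates that naming; the helper is illustrative, not the function actually defined in setup.py.

# Illustrative only: approximates how the pinned CUDA version could appear in the
# wheel's local-version tag (e.g. 'cu123torch2.3cxx11abiFALSE'); the real logic
# lives in setup.py's get_wheel_url and may differ in detail.
from packaging.version import parse

def pinned_cuda(torch_cuda: str) -> str:
    # Mirrors the changed line: CUDA 11.x maps to 11.8, CUDA 12.x now maps to 12.3.
    return "11.8" if parse(torch_cuda).major == 11 else "12.3"

def wheel_tag(torch_cuda: str, torch_version: str, cxx11_abi: bool) -> str:
    cuda = pinned_cuda(torch_cuda).replace(".", "")
    torch_mm = ".".join(torch_version.split(".")[:2])
    return f"cu{cuda}torch{torch_mm}cxx11abi{str(cxx11_abi).upper()}"

print(wheel_tag("12.4", "2.3.1", False))  # -> cu123torch2.3cxx11abiFALSE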