gaoqiong / flash-attention
Commit 799f56fa authored Sep 17, 2023 by Tri Dao
Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults
Parent: c984208d
Showing 4 changed files with 10 additions and 6 deletions:
.github/workflows/publish.yml  +5 -0
flash_attn/__init__.py         +1 -1
setup.py                       +2 -0
training/Dockerfile            +2 -5
.github/workflows/publish.yml

@@ -80,6 +80,11 @@ jobs:
             cuda-version: '11.7.1'
           - torch-version: '2.1.0.dev20230731'
             cuda-version: '11.8.0'
+          # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so
+          # we only use CUDA 12.2. setup.py has a special case that will
+          # download the wheel for CUDA 12.2 instead.
+          - torch-version: '2.1.0.dev20230731'
+            cuda-version: '12.1.0'
     steps:
       - name: Checkout
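The added entry encodes the rule from the commit message: the PyTorch 2.1 nightly is not built against CUDA 12.1.0 because nvcc 12.1.0 segfaults while compiling the kernels, and setup.py redirects those users to the CUDA 12.2 wheel instead. As a rough sketch (not part of the commit), the matrix-exclusion mechanics behave like the Python below; the version lists are illustrative, not the workflow's actual matrix.

from itertools import product

# Illustrative matrix axes; the real workflow enumerates more versions.
torch_versions = ["2.0.1", "2.1.0.dev20230731"]
cuda_versions = ["11.8.0", "12.1.0"]

# Mirrors the newly added exclusion: the PyTorch 2.1 nightly is not paired
# with CUDA 12.1.0 because nvcc 12.1.0 segfaults during compilation.
excluded = {("2.1.0.dev20230731", "12.1.0")}

jobs = [
    {"torch-version": t, "cuda-version": c}
    for t, c in product(torch_versions, cuda_versions)
    if (t, c) not in excluded
]

for job in jobs:
    print(job)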
flash_attn/__init__.py

-__version__ = "2.2.3.post1"
+__version__ = "2.2.3.post2"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
...
setup.py

@@ -223,6 +223,8 @@ def get_wheel_url():
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
     torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
+    if torch_version_raw.major == 2 and torch_version_raw.minor == 1:
+        torch_cuda_version = parse("12.2")
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     flash_version = get_package_version()
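The two added lines make get_wheel_url() treat any PyTorch 2.1 install as if it were built for CUDA 12.2, so the prebuilt CUDA 12.2 wheel is downloaded instead of compiling locally with nvcc 12.1. A minimal sketch of how that override could feed into a wheel URL follows; the helper name, base URL, and filename tags are illustrative assumptions, not the project's actual scheme.

import sys

import torch
from packaging.version import parse

BASE_URL = "https://example.com/flash-attn/wheels"  # placeholder host, not the real one


def sketch_wheel_url(flash_version: str) -> str:
    # Assumes a CUDA build of torch, i.e. torch.version.cuda is not None.
    torch_cuda_version = parse(torch.version.cuda)
    torch_version_raw = parse(torch.__version__)
    if torch_version_raw.major == 2 and torch_version_raw.minor == 1:
        # Same special case as the commit: point at the CUDA 12.2 wheel,
        # since no wheel is built with nvcc 12.1.
        torch_cuda_version = parse("12.2")
    python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    cuda_tag = f"cu{torch_cuda_version.major}{torch_cuda_version.minor}"
    torch_tag = f"torch{torch_version_raw.major}.{torch_version_raw.minor}"
    wheel_name = f"flash_attn-{flash_version}+{cuda_tag}{torch_tag}-{python_tag}-linux_x86_64.whl"
    return f"{BASE_URL}/{wheel_name}"


print(sketch_wheel_url("2.2.3.post2"))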
training/Dockerfile

@@ -85,14 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.2.3.post1
+RUN pip install flash-attn==2.2.3.post2
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.3.post1 \
-    && cd csrc/fused_softmax && pip install . && cd ../../ \
-    && cd csrc/rotary && pip install . && cd ../../ \
+    && cd flash-attention && git checkout v2.2.3.post2 \
     && cd csrc/layer_norm && pip install . && cd ../../ \
     && cd csrc/fused_dense_lib && pip install . && cd ../../ \
     && cd csrc/ft_attention && pip install . && cd ../../ \
     && cd .. && rm -rf flash-attention
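The Dockerfile now installs the flash-attn 2.2.3.post2 wheel and checks out the matching tag before building the extra CUDA extensions from csrc/. A small post-build sanity check along these lines could run inside the image; the extension module names are assumptions about what each csrc/ subdirectory installs, not something stated in the commit.

import importlib

import flash_attn

# The image pins flash-attn 2.2.3.post2 and builds the csrc/ extensions from
# the matching tag, so the installed version should match exactly.
assert flash_attn.__version__ == "2.2.3.post2", flash_attn.__version__

# Assumed mapping: csrc/layer_norm -> dropout_layer_norm,
# csrc/fused_dense_lib -> fused_dense_lib, csrc/ft_attention -> ft_attention.
for module in ("dropout_layer_norm", "fused_dense_lib", "ft_attention"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: missing ({exc})")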