gaoqiong / flash-attention · Commits

Commit 1dc1b6c8, authored Sep 03, 2023 by Tri Dao

Bump to v2.1.2

Parent: 0c04943f

Showing 4 changed files with 11 additions and 5 deletions (+11, -5)
.github/workflows/publish.yml            +7  -1
csrc/flash_attn/src/flash_fwd_kernel.h   +1  -1
flash_attn/__init__.py                   +1  -1
training/Dockerfile                      +2  -2
.github/workflows/publish.yml

@@ -63,11 +63,17 @@ jobs:
           # Pytorch <= 2.0 only supports CUDA <= 11.8
           - torch-version: '1.12.1'
             cuda-version: '12.1.0'
+          - torch-version: '1.12.1'
+            cuda-version: '12.2.0'
           - torch-version: '1.13.1'
             cuda-version: '12.1.0'
+          - torch-version: '1.13.1'
+            cuda-version: '12.2.0'
           - torch-version: '2.0.1'
             cuda-version: '12.1.0'
-          # Pytorch >= 2.1 only supports CUDA 12.1
+          - torch-version: '2.0.1'
+            cuda-version: '12.2.0'
+          # Pytorch >= 2.1 only supports CUDA >= 12.1
           - torch-version: '2.1.0.dev20230731'
             cuda-version: '11.6.2'
           - torch-version: '2.1.0.dev20230731'
...
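The new entries extend the workflow's exclude list so that wheels are not built for CUDA 12.2 against Pytorch 1.12, 1.13, or 2.0. As a rough illustration (plain Python, not the GitHub Actions matrix engine, and with abbreviated, hypothetical version lists), exclude entries simply filter combinations out of the cross-product of the matrix axes:

from itertools import product

# Abbreviated axis values for illustration; the real workflow lists many more.
torch_versions = ["1.12.1", "1.13.1", "2.0.1", "2.1.0.dev20230731"]
cuda_versions = ["11.8.0", "12.1.0", "12.2.0"]

# Mirrors the exclusions added above: no Pytorch <= 2.0 build against CUDA 12.2.
exclude = [{"torch-version": t, "cuda-version": "12.2.0"}
           for t in ("1.12.1", "1.13.1", "2.0.1")]

# Expand the matrix and drop excluded combinations.
jobs = [{"torch-version": t, "cuda-version": c}
        for t, c in product(torch_versions, cuda_versions)
        if {"torch-version": t, "cuda-version": c} not in exclude]

print(len(jobs))  # 4 * 3 = 12 combinations, minus the 3 excluded ones -> 9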
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -947,7 +947,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
         float sum = scores_sum(mi);
         float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+        lse(mi) = (sum == 0.f || sum != sum) ? -INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
         float scale = inv_sum;
         #pragma unroll
         for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
...
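The one-character change flips the sentinel written to lse for rows whose softmax denominator is zero or NaN in the split-KV kernel: such rows now record a log-sum-exp of -INFINITY rather than +INFINITY. A plausible reading, sketched below in NumPy rather than the repository's actual CUDA combine step, is that when per-split partial results are later merged by weighting each split with exp(lse_split - lse_total), a split that saw only masked-out keys should receive zero weight, which -inf gives and +inf does not. The helper name and combine formula here are illustrative assumptions, not code from the commit.

import numpy as np

def combine_splits(partial_outs, partial_lses):
    # Hypothetical model of a split-KV combine: weight each split's partial
    # output by its share of the global softmax denominator, exp(lse_i - lse_total).
    lses = np.array(partial_lses, dtype=np.float64)
    lse_total = np.logaddexp.reduce(lses)
    weights = np.exp(lses - lse_total)
    outs = np.asarray(partial_outs, dtype=np.float64)
    return sum(w * o for w, o in zip(weights, outs))

# Split 0 attended to real keys; split 1 saw only masked-out keys (softmax sum == 0).
out0, lse0 = np.array([1.0, 2.0]), 0.7
out1 = np.array([0.0, 0.0])

print(combine_splits([out0, out1], [lse0, -np.inf]))  # [1. 2.]   -- empty split is ignored
print(combine_splits([out0, out1], [lse0,  np.inf]))  # [nan nan] -- +inf poisons the merge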
flash_attn/__init__.py

-__version__ = "2.1.1"
+__version__ = "2.1.2"
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
...
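For context, flash_attn_func is the main user-facing entry point re-exported here. A minimal usage sketch, assuming a CUDA device, fp16 tensors, and the (batch, seqlen, nheads, headdim) layout of the 2.x interface; check the flash_attn_interface docstrings for the authoritative signature:

import torch
from flash_attn import flash_attn_func

# Assumed layout: (batch, seqlen, nheads, headdim), fp16/bf16, on a CUDA device.
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)  # same shape as q
print(out.shape)  # torch.Size([2, 1024, 8, 64])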
training/Dockerfile

@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.1.1
+RUN pip install flash-attn==2.1.2

 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.1.1 \
+    && cd flash-attention && git checkout v2.1.2 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
...