gaoqiong / flash-attention

Commit d30f2e1c, authored Aug 01, 2023 by Tri Dao

Bump to v2.0.4

Parent: 1c41d2b0
Showing 3 changed files with 5 additions and 5 deletions:

- README.md (+2, -2)
- flash_attn/__init__.py (+1, -1)
- training/Dockerfile (+2, -2)
README.md

@@ -101,7 +101,7 @@ Return:
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
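For reference, a minimal sketch of a GQA call matching the corrected docstring. The tensor shapes and the CUDA/fp16 setup are illustrative assumptions, not part of the commit:

```python
import torch
from flash_attn import flash_attn_func

# 6 query heads and 2 KV heads: 6 is divisible by 2, so query heads 0-2
# share KV head 0 and query heads 3-5 share KV head 1, as the docstring says.
batch, seqlen, nheads_q, nheads_kv, headdim = 2, 128, 6, 2, 64
q = torch.randn(batch, seqlen, nheads_q, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
print(out.shape)  # torch.Size([2, 128, 6, 64])
```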
@@ -131,7 +131,7 @@ These functions have been renamed:
 If the inputs have the same sequence lengths in the same batch, it is simpler
 and faster to use these functions:
 ```python
-flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
+flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
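With the updated signature, dropout_p can simply be omitted during evaluation. A minimal sketch, again assuming a CUDA device, fp16 inputs, and illustrative shapes:

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

# Packed QKV layout: (batch, seqlen, 3, nheads, headdim).
batch, seqlen, nheads, headdim = 2, 128, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)

# dropout_p now defaults to 0.0, so it can be left out for inference/evaluation.
out = flash_attn_qkvpacked_func(qkv, softmax_scale=None, causal=False)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```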
flash_attn/__init__.py

-__version__ = "2.0.3"
+__version__ = "2.0.4"
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
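A trivial sanity check (not part of the commit) that the installed package reports the bumped version:

```python
import flash_attn

# The package should report the version set in flash_attn/__init__.py.
assert flash_attn.__version__ == "2.0.4", flash_attn.__version__
```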
training/Dockerfile

@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.0.3
+RUN pip install flash-attn==2.0.4

 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.0.3 \
+    && cd flash-attention && git checkout v2.0.4 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
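The Dockerfile pins the wheel to 2.0.4 and builds the csrc/ CUDA extensions from the matching tag. A rough post-build check is sketched below; the extension module names (rotary_emb, xentropy_cuda_lib) are assumptions about what csrc/rotary and csrc/xentropy install, not taken from this diff:

```python
import importlib.util

import flash_attn

# Wheel pinned by "pip install flash-attn==2.0.4" in the image.
print("flash-attn version:", flash_attn.__version__)

# Assumed names of the compiled extension modules; if they differ in a given
# release, this simply prints False rather than failing.
for mod in ("rotary_emb", "xentropy_cuda_lib"):
    print(mod, "available:", importlib.util.find_spec(mod) is not None)
```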