gaoqiong / flash-attention

Commit 318e2f1b (unverified)
Authored Mar 15, 2023 by Tri Dao; committed by GitHub on Mar 15, 2023
Merge pull request #140 from VikParuchuri/main
Remove unused kwargs like device in FlashAttention
Parents: e45a46a5, 31653980
Showing 1 changed file with 3 additions and 3 deletions.
flash_attn/flash_attention.py (+3, -3)
@@ -18,7 +18,7 @@ class FlashAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
-    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+    def __init__(self, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
@@ -74,7 +74,7 @@ class FlashMHA(nn.Module):
 class FlashMHA(nn.Module):
 
     def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
-                 causal=False, device=None, dtype=None, **kwargs) -> None:
+                 causal=False, device=None, dtype=None) -> None:
         assert batch_first
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -87,7 +87,7 @@ class FlashMHA(nn.Module):
         assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
 
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
     def forward(self, x, key_padding_mask=None, need_weights=False):
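For context, a minimal usage sketch of FlashMHA after this change (the sizes, dtype, and hyperparameters below are illustrative assumptions, not values from the repository). device and dtype are still accepted by FlashMHA and routed through factory_kwargs to its nn.Linear projections; they are simply no longer forwarded to the inner FlashAttention module, which did not use them.

import torch
from flash_attn.flash_attention import FlashMHA

# head_dim = 1024 / 16 = 64, which satisfies the "head_dim <= 128 and
# divisible by 8" assertion in FlashMHA.__init__.
mha = FlashMHA(embed_dim=1024, num_heads=16, attention_dropout=0.1,
               causal=True, device='cuda', dtype=torch.float16)

# The FlashAttention kernels expect fp16/bf16 tensors on a CUDA device;
# input layout is (batch, seqlen, embed_dim) since batch_first is required.
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
out, _ = mha(x)   # forward returns an (output, attention_weights) tuple in this version
print(out.shape)  # torch.Size([2, 512, 1024])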