gaoqiong / flash-attention / Commits

Commit 04fb1985 (unverified)
Authored Sep 06, 2022 by Tri Dao; committed by GitHub on Sep 06, 2022

Merge pull request #43 from eric-tc-wong/patch-1

Update flash_attention.py

Parents: 19d12610, b410d14f
Showing 1 changed file with 1 addition and 1 deletion:

flash_attn/flash_attention.py (+1, -1)
--- a/flash_attn/flash_attention.py
+++ b/flash_attn/flash_attention.py
@@ -107,7 +107,7 @@ class FlashMHA(nn.Module):
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                           h=self.num_heads).unbind(dim=2)
             query, key = self.rotary_emb(query, key, seq_dimension=-3)
-            qkv = torch.stack([query, key, value], dim=2)
+            qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
         else:
             qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
         context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
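The one-line change casts query and key back to x.dtype before re-stacking them into qkv. A plausible reading (not stated in the commit message) is that self.rotary_emb can return its outputs in a higher precision than the input x, so the re-stacked qkv would otherwise leave the half precision the FlashAttention kernel expects. The sketch below is illustrative only: the tensor shapes, the explicit .float() standing in for an upcasting rotary embedding, and the printed dtypes are assumptions, not code from the repository.

import torch
from einops import rearrange

# Illustrative shapes (not from the commit): batch 2, seqlen 4,
# 2 heads of dim 8, packed as (three h d) on the last axis.
x = torch.randn(2, 4, 3 * 2 * 8, dtype=torch.float16)
qkv = x  # in FlashMHA this would come from the Wqkv projection of x

# Same split as the diff's context lines: reshape to (b, s, three, h, d),
# then unbind along the "three" axis into query / key / value.
query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d',
                              three=3, h=2).unbind(dim=2)

# Stand-in for a rotary embedding that returns float32 outputs (assumption).
query, key = query.float(), key.float()

# The commit's fix: cast query/key back to x.dtype before stacking, so qkv
# stays in half precision instead of being promoted (or erroring, depending
# on the PyTorch version) now that query/key no longer match value's dtype.
qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
print(qkv.shape, qkv.dtype)  # torch.Size([2, 4, 3, 2, 8]) torch.float16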