gaoqiong / flash-attention · Commit 55797f32
Authored Nov 10, 2022 by Tri Dao
Remove RotaryEmbedding from FlashAttention module

To avoid an import error if one doesn't have rotary_emb installed.
Parent: 6998e0ec
Showing 1 changed file with 3 additions and 17 deletions:

flash_attn/flash_attention.py (+3, -17)
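The import error the commit message refers to comes from the unconditional top-level import of RotaryEmbedding and RotaryEmbedding2D from flash_attn.rotary, which needs the optional rotary_emb extension; without it, merely importing flash_attn.flash_attention fails. The commit fixes this by dropping the rotary code path from the module entirely. Purely as an illustration of the failure mode, and not what this commit does, a guarded import is the other common way to keep an optional dependency from breaking module import:

    # Illustrative sketch only -- NOT the approach taken in this commit.
    # Guarding the optional dependency keeps `import flash_attn.flash_attention`
    # working even when the rotary_emb extension is missing.
    try:
        from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
    except ImportError:  # rotary_emb extension not installed
        RotaryEmbedding, RotaryEmbedding2D = None, None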
flash_attn/flash_attention.py @ 55797f32

@@ -4,7 +4,6 @@ import torch.nn as nn

 from einops import rearrange

-from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

@@ -75,7 +74,7 @@ class FlashAttention(nn.Module):
 class FlashMHA(nn.Module):

     def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
-                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
+                 causal=False, device=None, dtype=None, **kwargs) -> None:
         assert batch_first
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -85,14 +84,7 @@ class FlashMHA(nn.Module):
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
-        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
-        assert use_rotary_emb in [None, '1d', '2d']
-        self.use_rotary_emb = use_rotary_emb
-        if self.use_rotary_emb == '1d':
-            self.rotary_emb = RotaryEmbedding(self.head_dim)
-        elif self.use_rotary_emb == '2d':
-            self.rotary_emb = RotaryEmbedding2D(self.head_dim)
+        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)

@@ -103,12 +95,6 @@ class FlashMHA(nn.Module):
             key_padding_mask: bool tensor of shape (batch, seqlen)
         """
         qkv = self.Wqkv(x)
-        if self.use_rotary_emb:
-            query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d',
-                                          three=3, h=self.num_heads).unbind(dim=2)
-            query, key = self.rotary_emb(query, key, seq_dimension=-3)
-            qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
-        else:
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
         context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                 need_weights=need_weights, causal=self.causal)
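After this change, FlashMHA no longer accepts a use_rotary_emb argument and never touches flash_attn.rotary, so constructing and calling it works without the rotary_emb extension. A minimal usage sketch follows; the forward signature (key_padding_mask, need_weights) and the (output, attention weights) return value are assumed from the surrounding file rather than shown in this diff, and the shapes and hyperparameters are illustrative:

    # Minimal usage sketch after this commit (assumptions noted above).
    import torch
    from flash_attn.flash_attention import FlashMHA

    device, dtype = 'cuda', torch.float16  # FlashAttention kernels run in fp16/bf16 on GPU

    # head_dim = 1024 / 16 = 64, which satisfies the new assert (<= 128, divisible by 8).
    mha = FlashMHA(embed_dim=1024, num_heads=16, attention_dropout=0.1,
                   causal=True, device=device, dtype=dtype)  # no use_rotary_emb argument anymore
    x = torch.randn(2, 512, 1024, device=device, dtype=dtype)  # (batch, seqlen, embed_dim)
    out, attn_weights = mha(x)  # rotary embeddings, if wanted, must now be applied outside this module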