gaoqiong / flash-attention · Commits · 684196b8

Commit 684196b8 (Unverified)
Authored Jul 23, 2023 by Kiarash Jamali; committed by GitHub on Jul 23, 2023

Allow rotary embeddings for Bert (#363)
parent cbf982af

Showing 1 changed file with 7 additions and 2 deletions (+7 -2)
flash_attn/models/bert.py  (view file @ 684196b8)
@@ -52,10 +52,16 @@ logger = logging.getLogger(__name__)

 def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    rotary_kwargs = {}
+    if config.position_embedding_type == "rotary":
+        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                         dropout=config.attention_probs_dropout_prob, causal=False,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
-                        return_residual=return_residual)
+                        return_residual=return_residual, **rotary_kwargs)
     return mixer_cls
@@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel):

         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
         self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
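Dropping the assertion on position_embedding_type == 'absolute' is what actually lets a rotary config through BertModel.__init__. Assuming the config sketched above, construction would look like this (a sketch, not a statement about the rest of the model's behavior):

from flash_attn.models.bert import BertModel

# Before this commit, the removed assert raised an AssertionError for a
# rotary config; with this change the model can be constructed.
model = BertModel(config)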