gaoqiong / flash-attention · Commits

Commit 1feb9426, authored Nov 23, 2022 by Tri Dao

[ViT] Use dropout_add_ln for the 1st layer norm

parent 45bcf37b

Showing 3 changed files with 27 additions and 10 deletions (+27 -10):

flash_attn/models/gpt.py   +7  -7
flash_attn/models/vit.py   +20 -2
flash_attn/modules/mlp.py  +0  -1
flash_attn/models/gpt.py

@@ -104,14 +104,14 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
                 nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


-class GPT2Model(nn.Module):
+class GPTModel(nn.Module):

     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
-        if self.pad_vocab_size_multiple_8:
-            if config.vocab_size % 8 != 0:
-                config.vocab_size += 8 - (config.vocab_size % 8)
+        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+        if config.vocab_size % self.pad_vocab_size_multiple != 0:
+            config.vocab_size += (self.pad_vocab_size_multiple
+                                  - (config.vocab_size % self.pad_vocab_size_multiple))
         self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                          config.max_position_embeddings)

@@ -153,11 +153,11 @@ class GPT2Model(nn.Module):
         return hidden_states


-class GPT2LMHeadModel(nn.Module):
+class GPTLMHeadModel(nn.Module):

     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.transformer = GPT2Model(config)
+        self.transformer = GPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # Initialize weights and apply final processing
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
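As an aside on the gpt.py change above: the new pad_vocab_size_multiple option generalizes the old pad_vocab_size_multiple_8 flag by rounding the vocabulary size up to the next multiple of a configurable value. A minimal sketch of that rounding, using GPT-2's 50257-token vocabulary purely as an illustrative input (the concrete numbers are not part of this diff):

def pad_vocab_size(vocab_size: int, multiple: int = 1) -> int:
    # Same arithmetic as the GPTModel.__init__ change above.
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

print(pad_vocab_size(50257, 8))  # 50264: rounded up to the next multiple of 8
print(pad_vocab_size(50264, 8))  # 50264: already a multiple, left unchanged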
flash_attn/models/vit.py

@@ -18,6 +18,11 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
 from flash_attn.modules.block import Block

+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+except ImportError:
+    dropout_add_layer_norm = None
+

 def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                      cross_attn=False):

@@ -152,6 +157,10 @@ class VisionTransformer(nn.Module):
         # (in the pretrained weight) is the final layer norm.
         self.norm_0 = norm_layer(embed_dim)

+        self.fused_dropout_add_ln = fused_dropout_add_ln
+        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+            raise ImportError('dropout_add_layer_norm is not installed')
+
         self.blocks = nn.ModuleList([create_block(
             embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
             drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
@@ -193,7 +202,7 @@ class VisionTransformer(nn.Module):
             if self.cls_token is not None:
                 x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
             x = x + self.pos_embed
-        return self.pos_drop(x)
+        return x

     def forward_features(self, x, all_tokens=True):
         """

@@ -201,8 +210,17 @@
         cls token.
         """
         x = self.patch_embed(x)
+        x = self._pos_embed(x)
         # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
-        residual = self._pos_embed(x).float()
+        if not self.fused_dropout_add_ln:
+            residual = self.pos_drop(x).float()
+            hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
+        else:
+            hidden_states, residual = dropout_add_layer_norm(
+                x, None, self.norm_0.weight, self.norm_0.bias,
+                self.pos_drop.p if self.training else 0.0, self.norm_0.eps,
+                prenorm=True, residual_in_fp32=True
+            )
         hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
         if self.global_pool != 'token' or all_tokens:
             for block in self.blocks:
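For readers unfamiliar with the fused op: in this prenorm, no-incoming-residual call, dropout_add_layer_norm(x, None, weight, bias, p, eps, prenorm=True, residual_in_fp32=True) is expected to match the non-fused fallback branch above — dropout, an fp32 residual, then LayerNorm in the weight dtype — in a single kernel. The sketch below is an unfused reference of that behaviour, not the library implementation; the argument names are taken from the call in the diff.

import torch.nn.functional as F

def dropout_add_ln_reference(x, weight, bias, p, eps, training=True):
    # Unfused sketch of the prenorm path used in forward_features above,
    # with no incoming residual (the second argument of the fused call is None).
    residual = F.dropout(x, p=p if training else 0.0).float()   # residual kept in fp32
    hidden_states = F.layer_norm(residual.to(dtype=weight.dtype),
                                 (x.shape[-1],), weight, bias, eps)
    # prenorm=True: the fused op returns both the normalized output and the residual,
    # so the residual stream can be carried through the transformer blocks.
    return hidden_states, residual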
flash_attn/modules/mlp.py

@@ -64,7 +64,6 @@ class FusedDenseGeluDense(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

     def forward(self, x):
-        assert x.dtype in [torch.float16, torch.bfloat16]
         assert x.is_cuda
         fn = (fused_dense_gelu_dense_function_td if not self.return_residual
               else fused_dense_res_gelu_dense_function_td)
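For context on mlp.py: FusedDenseGeluDense fuses the same computation as the plain Mlp imported next to it in vit.py — Linear, GELU, Linear — into fused CUDA kernels. A minimal unfused sketch of that computation, with illustrative names rather than the module's actual code:

import torch.nn as nn
import torch.nn.functional as F

class MlpReference(nn.Module):
    # Unfused Linear -> GELU -> Linear; FusedDenseGeluDense computes the same thing
    # with fused kernels (fused_dense_gelu_dense_function_td above).
    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))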