Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
6fc1e07d
Commit
6fc1e07d
authored
Jul 21, 2023
by
Tri Dao
Browse files
[Block] Re-enable DropPath
parent
9ee0ff1d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
19 deletions
+17
-19
flash_attn/modules/block.py
flash_attn/modules/block.py
+17
-19
No files found.
flash_attn/modules/block.py
View file @
6fc1e07d
...
@@ -8,7 +8,7 @@ import torch.nn as nn
...
@@ -8,7 +8,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
Tensor
from
torch
import
Tensor
#
from torchvision.ops import StochasticDepth
from
torchvision.ops
import
StochasticDepth
from
flash_attn.modules.mha
import
MHA
from
flash_attn.modules.mha
import
MHA
from
flash_attn.modules.mlp
import
Mlp
from
flash_attn.modules.mlp
import
Mlp
...
@@ -70,12 +70,12 @@ class Block(nn.Module):
...
@@ -70,12 +70,12 @@ class Block(nn.Module):
mlp_cls
=
partial
(
Mlp
,
hidden_features
=
4
*
dim
)
mlp_cls
=
partial
(
Mlp
,
hidden_features
=
4
*
dim
)
self
.
mixer
=
mixer_cls
(
dim
)
self
.
mixer
=
mixer_cls
(
dim
)
self
.
dropout1
=
dropout_cls
(
resid_dropout1
)
self
.
dropout1
=
dropout_cls
(
resid_dropout1
)
#
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self
.
drop_path1
=
StochasticDepth
(
drop_path1
,
mode
=
'row'
)
self
.
norm1
=
norm_cls
(
dim
)
self
.
norm1
=
norm_cls
(
dim
)
self
.
mlp
=
mlp_cls
(
dim
)
self
.
mlp
=
mlp_cls
(
dim
)
if
not
isinstance
(
self
.
mlp
,
nn
.
Identity
):
if
not
isinstance
(
self
.
mlp
,
nn
.
Identity
):
self
.
dropout2
=
dropout_cls
(
resid_dropout2
)
self
.
dropout2
=
dropout_cls
(
resid_dropout2
)
#
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self
.
drop_path2
=
StochasticDepth
(
drop_path2
,
mode
=
'row'
)
self
.
norm2
=
norm_cls
(
dim
)
self
.
norm2
=
norm_cls
(
dim
)
if
self
.
fused_dropout_add_ln
:
if
self
.
fused_dropout_add_ln
:
...
@@ -129,14 +129,13 @@ class Block(nn.Module):
...
@@ -129,14 +129,13 @@ class Block(nn.Module):
if
self
.
residual_in_fp32
:
if
self
.
residual_in_fp32
:
residual
=
residual
.
to
(
torch
.
float32
)
residual
=
residual
.
to
(
torch
.
float32
)
else
:
else
:
if
self
.
drop_path1
.
p
==
0
or
not
self
.
training
:
rowscale1
=
None
rowscale1
=
None
# if self.drop_path1.p == 0 or not self.training:
else
:
# rowscale1 = None
rowscale1
=
self
.
drop_path1
(
torch
.
ones
(
# else:
hidden_states
.
shape
[:
-
1
],
device
=
hidden_states
.
device
,
# rowscale1 = self.drop_path1(torch.ones(
dtype
=
hidden_states
.
dtype
)
# hidden_states.shape[:-1], device=hidden_states.device,
)
# dtype=hidden_states.dtype)
# )
hidden_states
,
residual
=
fused_add_norm_fn
(
hidden_states
,
residual
=
fused_add_norm_fn
(
hidden_states
,
residual
,
self
.
norm1
.
weight
,
self
.
norm1
.
bias
,
hidden_states
,
residual
,
self
.
norm1
.
weight
,
self
.
norm1
.
bias
,
self
.
dropout1
.
p
if
self
.
training
else
0.0
,
self
.
norm1
.
eps
,
self
.
dropout1
.
p
if
self
.
training
else
0.0
,
self
.
norm1
.
eps
,
...
@@ -157,14 +156,13 @@ class Block(nn.Module):
...
@@ -157,14 +156,13 @@ class Block(nn.Module):
if
self
.
residual_in_fp32
:
if
self
.
residual_in_fp32
:
residual
=
residual
.
to
(
torch
.
float32
)
residual
=
residual
.
to
(
torch
.
float32
)
else
:
else
:
# if self.drop_path2.p == 0 or not self.training:
if
self
.
drop_path2
.
p
==
0
or
not
self
.
training
:
# rowscale2 = None
# else:
# rowscale2 = self.drop_path2(torch.ones(
# hidden_states.shape[:-1], device=hidden_states.device,
# dtype=hidden_states.dtype)
# )
rowscale2
=
None
rowscale2
=
None
else
:
rowscale2
=
self
.
drop_path2
(
torch
.
ones
(
hidden_states
.
shape
[:
-
1
],
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
)
hidden_states
,
residual
=
fused_add_norm_fn
(
hidden_states
,
residual
=
fused_add_norm_fn
(
hidden_states
,
residual
,
self
.
norm2
.
weight
,
self
.
norm2
.
bias
,
hidden_states
,
residual
,
self
.
norm2
.
weight
,
self
.
norm2
.
bias
,
self
.
dropout2
.
p
if
self
.
training
else
0.0
,
self
.
norm2
.
eps
,
self
.
dropout2
.
p
if
self
.
training
else
0.0
,
self
.
norm2
.
eps
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment