gaoqiong / flash-attention · Commits

Commit 4c98d0b4, authored Jul 26, 2023 by Tri Dao

    [MLP] Edit ParallelGatedMlp

Parent: 8ee62efc

Showing 1 changed file with 18 additions and 13 deletions:

flash_attn/modules/mlp.py (+18, -13)
@@ -11,10 +11,9 @@ except ImportError:
     ColumnParallelLinear, RowParallelLinear = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP, ColumnParallelLinear, RowParallelLinear
+    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
 except ImportError:
     FusedMLP, ParallelFusedMLP = None, None
-    ColumnParallelLinear, RowParallelLinear = None, None
 
 
 class Mlp(nn.Module):
@@ -87,25 +86,31 @@ class GatedMlp(nn.Module):
         return y if not self.return_residual else (y, x)
 
 
-class ParallelGatedMlp(GatedMlp):
+class ParallelGatedMlp(nn.Module):
     """ Parallel GatedMlp """
 
-    def __init__(self, in_features, process_group, hidden_features=None, out_features=None, activation=F.sigmoid,
-                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
+    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
+                 activation=F.sigmoid,
+                 bias1=True, bias2=True, multiple_of=256,
                  sequence_parallel=True, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__(in_features, hidden_features=hidden_features, out_features=out_features, activation=activation,
-                         bias1=bias1, bias2=bias2, multiple_of=multiple_of, return_residual=return_residual,
-                         device=device, dtype=dtype)
+        super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or int(8 * in_features / 3)
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
         self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
                                         bias=bias1,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.activation = activation
         self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                      bias=bias2,
                                      sequence_parallel=sequence_parallel, **factory_kwargs)
+
+    def forward(self, x):
+        y = self.fc1(x)
+        if self.activation == F.sigmoid:  # Special case for GLU
+            y = F.glu(y, dim=-1)
+        else:
+            y, gate = y.chunk(2, dim=-1)
+            y = y * self.activation(gate)
+        y = self.fc2(y)
+        return y
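A note on the hidden size computed in __init__: hidden_features defaults to int(8 * in_features / 3), is rounded up to the next multiple of multiple_of, and fc1 projects to twice that width so it produces both the value and the gate. A minimal sketch of that arithmetic, using an assumed in_features of 4096 that is not part of this commit:

    # Mirror the hidden-size arithmetic from ParallelGatedMlp.__init__ (illustrative values only).
    in_features = 4096   # assumed example width, not taken from the commit
    multiple_of = 256    # default in this version of the module

    hidden_features = int(8 * in_features / 3)        # 10922
    hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
    print(hidden_features)                            # 11008: 10922 rounded up to a multiple of 256
    print(2 * hidden_features)                        # 22016: output width of fc1 (value + gate)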
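The new forward treats activation == F.sigmoid as a special case and calls F.glu, which splits its input in half along the given dimension and returns first_half * sigmoid(second_half); the generic branch does the same thing explicitly with chunk. A small self-contained sketch of that equivalence on a plain tensor, without the parallel layers or a process group:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    y = torch.randn(2, 4, 16)  # stand-in for the fc1 output; last dim is 2 * hidden_features

    # Path taken when activation is F.sigmoid.
    out_glu = F.glu(y, dim=-1)

    # Generic path: split into value and gate halves, then gate with the activation.
    value, gate = y.chunk(2, dim=-1)
    out_generic = value * torch.sigmoid(gate)

    print(torch.allclose(out_glu, out_generic))  # True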