gaoqiong / flash-attention · Commits

Commit 364a5b4a, authored Aug 10, 2023 by Tri Dao
Parent: d30f2e1c

[MLP] Change the check for out_features being None

Showing 1 changed file with 10 additions and 8 deletions.

flash_attn/modules/mlp.py (+10 -8)
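The diff below replaces the `x or default` idiom with an explicit `is not None` test in all four MLP variants. The two are not equivalent: `or` substitutes the default for every falsy argument (0, empty containers, and so on), while the explicit check substitutes it only when the argument is actually None. A minimal sketch of the difference (the helper names are illustrative, not from the repository):

    def old_default(value, default):
        # `or` falls back for any falsy value: None, 0, '', [], ...
        return value or default

    def new_default(value, default):
        # falls back only when the value is actually None
        return value if value is not None else default

    print(old_default(0, 512))   # 512 -- an explicit 0 is silently overridden
    print(new_default(0, 512))   # 0   -- an explicit 0 is preserved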
@@ -22,8 +22,8 @@ class Mlp(nn.Module):
                  bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features * 4
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = hidden_features if hidden_features is not None else in_features * 4
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
@@ -45,8 +45,8 @@ class ParallelMLP(nn.Module):
         super().__init__()
         assert ColumnParallelLinear is not None, "Need to install fused_dense"
         assert RowParallelLinear is not None, "Need to install fused_dense"
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features * 4
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = hidden_features if hidden_features is not None else in_features * 4
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
                                         sequence_parallel=sequence_parallel, **factory_kwargs)
         self.activation = activation
@@ -67,8 +67,9 @@ class GatedMlp(nn.Module):
                  device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or int(8 * in_features / 3)
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (hidden_features if hidden_features is not None
+                           else int(8 * in_features / 3))
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
@@ -94,8 +95,9 @@ class ParallelGatedMlp(nn.Module):
                  sequence_parallel=True, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or int(8 * in_features / 3)
+        out_features = out_features if out_features is not None else in_features
+        hidden_features = (hidden_features if hidden_features is not None
+                           else int(8 * in_features / 3))
         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError('fused_dense is not installed')
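As a worked example of the sizing logic visible in the GatedMlp context lines above (the concrete values are illustrative, not from the commit): gated MLPs conventionally default to a hidden size of about 8/3 times in_features, so that parameter count roughly matches a standard 4x MLP despite the extra gate projection, and the round-up expression then pads it to a multiple of multiple_of:

    in_features = 1024   # illustrative value
    multiple_of = 128    # illustrative value
    hidden_features = int(8 * in_features / 3)  # 2730
    # round up to the nearest multiple of `multiple_of`
    hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
    print(hidden_features)  # 2816, the next multiple of 128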