gaoqiong / flash-attention

Commit 3557e0bb, authored Sep 04, 2023 by Tri Dao
Parent: 37c6e054

[MLP] Implement SwiGLU with torch jiterator
Showing 2 changed files with 41 additions and 1 deletion:

  flash_attn/modules/mlp.py        +10 -1
  flash_attn/ops/activations.py    +31 -0
flash_attn/modules/mlp.py

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
+
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
...

@@ -120,6 +126,9 @@ class GatedMlp(nn.Module):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
         else:
             y, gate = y.chunk(2, dim=-1)
             y = y * self.activation(gate)
...
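For reference, the new elif branch is a drop-in fused path: swiglu(gate, y) computes the same value as the generic else branch, y * F.silu(gate), in a single CUDA kernel. A minimal equivalence sketch, assuming flash-attn is installed and a CUDA device is available (shapes here are hypothetical):

    import torch
    import torch.nn.functional as F
    from flash_attn.ops.activations import swiglu

    h = torch.randn(2, 16, 1024, device="cuda", dtype=torch.float16)  # stand-in for self.fc1(x)
    y, gate = h.chunk(2, dim=-1)       # same split as GatedMlp.forward
    fused = swiglu(gate, y)            # fused kernel: silu(gate) * y
    ref = y * F.silu(gate)             # what the unfused else branch computes
    torch.testing.assert_close(fused, ref, rtol=1e-3, atol=1e-3)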
flash_attn/ops/activations.py

@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+
+swiglu = SwiGLUFunction.apply
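The swiglu_bwd kernel hard-codes the analytic gradient of out = silu(x) * y: dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))) * g * y and dy = silu(x) * g. A quick sanity-check sketch against PyTorch's own autograd, again assuming a CUDA device since torch.cuda.jiterator kernels only run on GPU:

    import torch
    import torch.nn.functional as F
    from flash_attn.ops.activations import swiglu

    x = torch.randn(8, 256, device="cuda", requires_grad=True)
    y = torch.randn(8, 256, device="cuda", requires_grad=True)

    out = swiglu(x, y)                    # fused forward: silu(x) * y
    out.backward(torch.ones_like(out))    # fused backward via SwiGLUFunction.backward

    x_ref = x.detach().clone().requires_grad_()
    y_ref = y.detach().clone().requires_grad_()
    ref = F.silu(x_ref) * y_ref           # unfused reference
    ref.backward(torch.ones_like(ref))

    torch.testing.assert_close(out, ref)
    torch.testing.assert_close(x.grad, x_ref.grad)
    torch.testing.assert_close(y.grad, y_ref.grad)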