gaoqiong / flash-attention — Commits — 081c2b01

Unverified commit 081c2b01, authored Apr 13, 2023 by Kirthi Shankar Sivamani, committed by GitHub on Apr 13, 2023.

    Merge branch 'HazyResearch:main' into enable_cuda_graph_capture

Parents: 7d25a4ec, 1c9ef9b3
Showing 5 changed files with 46 additions and 34 deletions:

  flash_attn/models/gpt.py         +2   -2
  flash_attn/ops/activations.py    +18  -1
  flash_attn/ops/fused_dense.py    +19  -18
  flash_attn/ops/triton/mlp.py     +1   -11
  flash_attn/utils/generation.py   +6   -2
flash_attn/models/gpt.py  (view file @ 081c2b01)

@@ -93,7 +93,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '

@@ -123,7 +123,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         activation = ('gelu_approx' if config.activation_function
-                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
         parallel_kwargs = ({'process_group': process_group,
                             'sequence_parallel': getattr(config, 'sequence_parallel', True)}
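Note: with the relaxed asserts above, a config that selects the 'sqrelu' activation can now take the fused MLP path, and the activation name is passed through unchanged instead of being coerced to 'relu'. A minimal usage sketch, not part of this commit; it assumes transformers' GPT2Config as the config class, as elsewhere in flash_attn/models/gpt.py:

# Hypothetical config sketch: 'sqrelu' now passes the asserts in create_mlp_cls.
from transformers import GPT2Config

config = GPT2Config(n_embd=1024, n_head=16, n_layer=24,
                    activation_function='sqrelu')  # squared ReLU
config.fused_mlp = True  # request the FusedMLP branch; before this commit, 'sqrelu' failed the assert here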
flash_attn/ops/gelu_activation.py → flash_attn/ops/activations.py  (view file @ 081c2b01)

@@ -2,7 +2,8 @@
 import math

 import torch
-from torch import nn
+import torch.nn as nn
+import torch.nn.functional as F


 # 1/sqrt(2*pi)-> 0.3989423

@@ -80,3 +81,19 @@ class FastGeLUFunction(torch.autograd.Function):
         return tmp

 fast_gelu_impl = FastGeLUFunction.apply
+
+
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
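Note: sqrelu_bwd is the analytic gradient of sqrelu_fwd, since d/dx relu(x)^2 = 2*relu(x) for x > 0 and 0 otherwise. A quick self-contained check against autograd (illustrative only, assuming a local PyTorch install; @torch.jit.script is dropped for brevity):

# Verify the hand-written squared-ReLU backward against autograd.
import torch
import torch.nn.functional as F

def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)

def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)

x = torch.randn(4, 8, requires_grad=True)
g = torch.randn(4, 8)
sqrelu_fwd(x).backward(g)                           # autograd gradient lands in x.grad
assert torch.allclose(x.grad, sqrelu_bwd(g, x.detach()))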
flash_attn/ops/fused_dense.py  (view file @ 081c2b01)

@@ -15,16 +15,11 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 # import fused_dense_cuda  # from apex
 import fused_dense_lib as fused_dense_cuda

-from flash_attn.ops.gelu_activation import gelu_bwd
+from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
 from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce


-@torch.jit.script
-def relu_bwd(g, x):
-    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
-
-
 class FusedDenseFunc(torch.autograd.Function):

     @staticmethod

@@ -209,7 +204,9 @@ class FusedMLPFunc(torch.autograd.Function):
         2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
+        if activation == 'sqrelu':
+            assert heuristic == -1
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]

@@ -248,8 +245,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if heuristic == -1:
             pre_act = F.linear(total_x, weight1, bias1)
             activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                             else F.relu)
-            output1 = activation_fn(pre_act)
+                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
             # This is before adding bias1
             # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):

@@ -279,7 +277,7 @@ class FusedMLPFunc(torch.autograd.Function):
         checkpoint_lvl = ctx.checkpoint_lvl
         activation = ctx.activation
         activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                         else F.relu)
+                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()

@@ -297,14 +295,16 @@ class FusedMLPFunc(torch.autograd.Function):
             pre_act, output1 = rest
         elif checkpoint_lvl == 1:
             pre_act, = rest
-            output1 = activation_fn(pre_act)
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
             if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 pre_act = F.linear(total_x, weight1, bias1)
-                output1 = activation_fn(pre_act)
+                with torch.jit.fuser('fuser2'):
+                    output1 = activation_fn(pre_act)
             else:
                 output1, pre_act = fused_dense_cuda.linear_act_forward(
                     total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,

@@ -324,8 +324,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if ctx.heuristic == -1:
             # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
+            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
+                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
             with torch.jit.fuser('fuser2'):
-                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                 grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
             # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't

@@ -380,7 +381,7 @@ def fused_mlp_func(
                    process_group: Optional[ProcessGroup] = None,
                    sequence_parallel: bool = True):
-    assert activation in ['gelu_approx', 'relu']
+    assert activation in ['gelu_approx', 'relu', 'sqrelu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)

@@ -428,7 +429,7 @@ class FusedMLP(nn.Module):
         to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:

@@ -436,7 +437,7 @@ class FusedMLP(nn.Module):
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

@@ -489,7 +490,7 @@ class ParallelFusedMLP(nn.Module):
         For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -499,7 +500,7 @@ class ParallelFusedMLP(nn.Module):
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                         bias=bias1, **factory_kwargs)
         self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
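Note: the changes above thread 'sqrelu' through FusedMLPFunc and the FusedMLP / ParallelFusedMLP wrappers; when it is selected, heuristic is forced to -1 so the activation runs through the torch.jit-scripted sqrelu_fwd / sqrelu_bwd path rather than a cublasLt epilogue. A rough usage sketch, not from the commit; it assumes a CUDA build of flash-attn with the fused_dense extension installed:

# Hypothetical sketch: FusedMLP with the newly allowed 'sqrelu' activation.
import torch
from flash_attn.ops.fused_dense import FusedMLP

mlp = FusedMLP(1024, 4 * 1024, activation='sqrelu',
               device='cuda', dtype=torch.float16)
assert mlp.heuristic == -1                      # forced to -1 for 'sqrelu' in __init__
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)                                      # forward uses sqrelu_fwd under heuristic == -1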
flash_attn/ops/triton/mlp.py  (view file @ 081c2b01)

@@ -8,17 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda

 from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
-
-
-@torch.jit.script
-def sqrelu_fwd(x):
-    r = F.relu(x)
-    return (r * r).to(dtype=x.dtype)
-
-
-@torch.jit.script
-def sqrelu_bwd(g, x):
-    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd


 class FusedDenseSqreluDenseFunc(torch.autograd.Function):
flash_attn/utils/generation.py  (view file @ 081c2b01)

@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                          fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())

@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
         if inference_params.sequence_len_offset >= max_length - 1:
             break

     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
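Note: the timing hunks add a torch.distributed.barrier() across tensor-parallel ranks before each torch.cuda.synchronize(), so the reported time reflects the slowest rank rather than whichever rank reaches the timer first, and the label now makes clear that prompt processing is included. The general pattern, as a standalone sketch (illustrative only; it assumes torch.distributed is already initialized whenever tensor_parallel > 1):

# Hypothetical helper showing the barrier-then-synchronize timing pattern.
import time
import torch

def timed(fn, tensor_parallel=1):
    if tensor_parallel > 1:
        torch.distributed.barrier()   # align all tensor-parallel ranks
    torch.cuda.synchronize()          # drain queued CUDA work before starting the clock
    start = time.time()
    out = fn()
    if tensor_parallel > 1:
        torch.distributed.barrier()
    torch.cuda.synchronize()
    print(f'elapsed: {(time.time() - start) * 1000:.0f}ms')
    return out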