gaoqiong / flash-attention · Commits

Commit fa6d1ce4, authored Nov 13, 2022 by Tri Dao (parent b92f2c3b)

Add fused_dense and dropout_add_layernorm CUDA extensions

Changes: 21 · Showing 1 changed file with 154 additions and 0 deletions
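For reviewers who want to run just this new module, a minimal invocation sketch using pytest's Python API; this assumes the flash_attn package and its fused_dense CUDA extension are built and a CUDA GPU is available (equivalent to running `pytest -q tests/ops/test_fused_dense.py` from the repository root):

# Sketch only: assumes flash_attn and its compiled fused_dense extension are importable
# and a CUDA device is present.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "tests/ops/test_fused_dense.py"]))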
tests/ops/test_fused_dense.py (new file, 0 → 100644) · +154, -0 · view file @ fa6d1ce4
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange

from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    out = model(x)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
                                  device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    out = model(x)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseResGeluDense(in_features, out_features, in_features,
                                   checkpoint_lvl=checkpoint_lvl, device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
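For readers skimming the diff, a minimal usage sketch of the two main modules exercised above, inferred only from the constructor and forward signatures these tests use (FusedDenseTD as a drop-in for torch.nn.Linear, FusedDenseGeluDenseTD as a fused Linear → GELU(tanh) → Linear block); treat it as an illustration of the tested interface rather than separately documented API:

# Usage sketch based solely on the signatures exercised in the tests above.
import torch
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD

device, dtype = 'cuda', torch.float16
x = torch.randn(8, 512, 1024, device=device, dtype=dtype)

# Behaves like torch.nn.Linear(1024, 1024) in test_fused_linear_bias.
linear = FusedDenseTD(1024, 1024, device=device, dtype=dtype)
y = linear(x)

# Fused MLP: fc1 (1024 -> 4096), GELU(approximate='tanh'), fc2 (4096 -> 1024),
# compared in test_fused_dense_gelu_dense against
# model_pt_fc2(F.gelu(model_pt_fc1(x), approximate='tanh')).
mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=0, heuristic=1,
                            device=device, dtype=dtype)
z = mlp(x)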