gaoqiong / flash-attention · Commits · 0e8c46ae

Commit 0e8c46ae, authored Aug 18, 2023 by Tri Dao
Parent: 7fcd3e6a

    Run isort and black on test files

Changes: 24
Showing 4 changed files with 891 additions and 476 deletions

    tests/ops/test_fused_dense.py            +83   -44
    tests/ops/test_fused_dense_parallel.py   +113  -72
    tests/test_flash_attn.py                 +684  -350
    tests/test_rotary.py                     +11   -10
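The formatting pass itself is mechanical. As a rough sketch (not taken from this repository's tooling config), the commit could be reproduced by running isort and then black over the test directory; the line length of 100 used below is an assumption inferred from which lines the diff leaves unwrapped.

# Hypothetical reproduction of the formatting pass; the tool options here are
# assumptions, not the repository's actual configuration.
import subprocess

for cmd in (
    ["isort", "--line-length", "100", "tests"],
    ["black", "--line-length", "100", "tests"],
):
    # check=True raises CalledProcessError if either formatter exits non-zero
    subprocess.run(cmd, check=True)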
tests/ops/test_fused_dense.py

 import math
 from functools import partial

+import pytest
 import torch
 import torch.nn.functional as F
-import pytest
 from einops import rearrange
 from flash_attn.ops.fused_dense import FusedDense, FusedMLP


-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize('return_residual', [False, True])
-@pytest.mark.parametrize('has_bias', [True, False])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-@pytest.mark.parametrize('in_features', [1024, 4096])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("return_residual", [False, True])
+@pytest.mark.parametrize("has_bias", [True, False])
+@pytest.mark.parametrize("out_features", [1024, 4096])
+@pytest.mark.parametrize("in_features", [1024, 4096])
 def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
-    device = 'cuda'
+    device = "cuda"
     rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
-                       requires_grad=True)
+    x_pt = torch.randn(
+        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
+    )
     x = x_pt.detach().clone().requires_grad_()
     model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
-    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
-                       device=device, dtype=dtype)
+    model = FusedDense(
+        in_features,
+        out_features,
+        bias=has_bias,
+        return_residual=return_residual,
+        device=device,
+        dtype=dtype,
+    )
     with torch.no_grad():
         model.weight.copy_(model_pt.weight)
         if has_bias:

@@ -37,10 +42,16 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
         out = model(x)
     else:
         out, x_copy = model(x)
-        x_copy = (x_copy[..., :out_features] if out_features < in_features
-                  else F.pad(x_copy, (0, out_features - in_features)))
-        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
-                     else F.pad(x_pt, (0, out_features - in_features)))
+        x_copy = (
+            x_copy[..., :out_features]
+            if out_features < in_features
+            else F.pad(x_copy, (0, out_features - in_features))
+        )
+        x_pt_copy = (
+            x_pt[..., :out_features]
+            if out_features < in_features
+            else F.pad(x_pt, (0, out_features - in_features))
+        )
         # Just add some random function of the residual
         out_pt = out_pt + F.gelu(x_pt_copy)
         out = out + F.gelu(x_copy)

@@ -60,43 +71,64 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
         assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)


-@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 # @pytest.mark.parametrize('dtype', [torch.float16])
-@pytest.mark.parametrize('heuristic', ['auto', -1])
+@pytest.mark.parametrize("heuristic", ["auto", -1])
 # @pytest.mark.parametrize('heuristic', ['auto'])
-@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
+@pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2])
 # @pytest.mark.parametrize('checkpoint_lvl', [1])
-@pytest.mark.parametrize('return_residual', [False, True])
+@pytest.mark.parametrize("return_residual", [False, True])
 # @pytest.mark.parametrize('return_residual', [False])
-@pytest.mark.parametrize('has_bias2', [True, False])
-@pytest.mark.parametrize('has_bias1', [True, False])
+@pytest.mark.parametrize("has_bias2", [True, False])
+@pytest.mark.parametrize("has_bias1", [True, False])
 # @pytest.mark.parametrize('has_bias2', [True])
 # @pytest.mark.parametrize('has_bias1', [True])
-@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
+@pytest.mark.parametrize("activation", ["gelu_approx", "relu"])
 # @pytest.mark.parametrize('activation', ['relu'])
-@pytest.mark.parametrize('out_features', [1024, 4096])
-@pytest.mark.parametrize('in_features', [1024, 4096])
+@pytest.mark.parametrize("out_features", [1024, 4096])
+@pytest.mark.parametrize("in_features", [1024, 4096])
 # @pytest.mark.parametrize('out_features', [4096])
 # @pytest.mark.parametrize('in_features', [1024])
-def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
-                   checkpoint_lvl, heuristic, dtype):
-    device = 'cuda'
+def test_fused_mlp(
+    in_features,
+    out_features,
+    activation,
+    has_bias1,
+    has_bias2,
+    return_residual,
+    checkpoint_lvl,
+    heuristic,
+    dtype,
+):
+    device = "cuda"
     rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
     batch_size = 8
     seqlen = 512
-    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
-                       requires_grad=True)
+    x_pt = torch.randn(
+        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
+    )
     x = x_pt.detach().clone().requires_grad_()
-    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
-                                   dtype=dtype)
-    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
-                                   dtype=dtype)
-    model = FusedMLP(in_features, out_features, in_features, activation=activation,
-                     bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
-                     checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
-                     device=device, dtype=dtype)
+    model_pt_fc1 = torch.nn.Linear(
+        in_features, out_features, bias=has_bias1, device=device, dtype=dtype
+    )
+    model_pt_fc2 = torch.nn.Linear(
+        out_features, in_features, bias=has_bias2, device=device, dtype=dtype
+    )
+    model = FusedMLP(
+        in_features,
+        out_features,
+        in_features,
+        activation=activation,
+        bias1=has_bias1,
+        bias2=has_bias2,
+        return_residual=return_residual,
+        checkpoint_lvl=checkpoint_lvl,
+        heuristic=heuristic,
+        device=device,
+        dtype=dtype,
+    )
     with torch.no_grad():
         model.fc1.weight.copy_(model_pt_fc1.weight)
         if has_bias1:

@@ -104,8 +136,11 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
         model.fc2.weight.copy_(model_pt_fc2.weight)
         if has_bias2:
             model.fc2.bias.copy_(model_pt_fc2.bias)
-    activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                     else partial(F.relu, inplace=True))
+    activation_fn = (
+        partial(F.gelu, approximate="tanh")
+        if activation == "gelu_approx"
+        else partial(F.relu, inplace=True)
+    )
     out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
     if not return_residual:
         out = model(x)

@@ -121,13 +156,17 @@ def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2,
     out_pt.backward(g)
     out.backward(g)
     # The error for relu is higher still
-    if activation == 'relu':
+    if activation == "relu":
         atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
     # The error for d_weight and d_bias is quite a bit higher
-    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(
+        model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10
+    )
     if has_bias1:
         assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
-    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
+    assert torch.allclose(
+        model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10
+    )
     if has_bias2:
         assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
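The tests above need a CUDA device and a flash-attn build that includes the fused dense extension; FusedDense and FusedMLP will otherwise fail at construction. A minimal way to run just these two tests programmatically (one option among several; the node IDs are taken from the file above) is:

# Sketch: run only the fused-dense tests shown above, quietly.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            [
                "tests/ops/test_fused_dense.py::test_fused_linear_bias",
                "tests/ops/test_fused_dense.py::test_fused_mlp",
                "-q",
            ]
        )
    )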
tests/ops/test_fused_dense_parallel.py
    (diff collapsed)

tests/test_flash_attn.py
    (diff collapsed)
tests/test_rotary.py

 import math

+import pytest
 import torch
 import torch.nn.functional as F
-import pytest
 from einops import rearrange
 from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch

-is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)
+is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)


-@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
+@pytest.mark.parametrize(
+    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
+)
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
-@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5])
+@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
 # @pytest.mark.parametrize('rotary_fraction', [0.5])
-@pytest.mark.parametrize('inplace', [False, True])
+@pytest.mark.parametrize("inplace", [False, True])
 # @pytest.mark.parametrize('inplace', [False])
 def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
     rtol = 1e-3

@@ -23,12 +23,13 @@ def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
     nheads = 4
     seqlen = 217
     headdim = 128
-    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda',
-                    requires_grad=True)
+    x = torch.randn(
+        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
+    )
     x_pt = x.detach().clone().requires_grad_()
     rotary_dim = int(rotary_fraction * headdim)
     assert rotary_dim % 2 == 0
-    angle = torch.randn(seqlen, rotary_dim // 2, device='cuda')
+    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
     cos = torch.cos(angle).to(dtype=dtype)
     sin = torch.sin(angle).to(dtype=dtype)
     out = apply_rotary_emb_func(x, cos, sin, inplace)
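For reference, the comparison that test_rotary_single_tensor sets up can be reproduced standalone. The sketch below assumes a CUDA device, full-width rotary (rotary_fraction = 1.0), and that apply_rotary_emb_torch accepts the same (x, cos, sin) leading arguments as the fused apply_rotary_emb_func; the shapes follow the test above.

# Sketch: fused rotary embedding vs. the pure-PyTorch reference, out of place.
import torch

from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch

batch_size, seqlen, nheads, headdim = 2, 217, 4, 128
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.randn(seqlen, headdim // 2, device="cuda")
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

out = apply_rotary_emb_func(x, cos, sin, False)  # fused CUDA kernel
out_ref = apply_rotary_emb_torch(x.float(), cos.float(), sin.float())  # fp32 reference

print((out.float() - out_ref).abs().max())  # expect roughly 1e-3-level error in fp16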