gaoqiong / flash-attention · Commits

Commit 52fb4b72, authored Oct 16, 2022 by Tri Dao
Fix #54: set device for multi-GPU case
parent 1b9facac

Showing 2 changed files with 82 additions and 0 deletions:
  csrc/flash_attn/fmha_api.cpp  +7   -0
  tests/test_flash_attn.py      +75  -0
csrc/flash_attn/fmha_api.cpp
...
...
@@ -28,6 +28,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "fmha.h"
...
...
@@ -246,6 +247,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
    int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    bool loop = max_seqlen_k > blocksize_c;
    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.get_device()};
    auto opts = q.options();
    // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
...
...
@@ -400,6 +404,9 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
    int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
    bool loop = max_seqlen_k > blocksize_c;
    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.get_device()};
    // It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
    auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(),
                                           torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
...
...
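The change relies on at::cuda::CUDAGuard being an RAII guard: constructing it from the query tensor's device makes that device current for the remainder of mha_fwd / mha_bwd, so allocations and kernel launches target the tensor's GPU instead of the default cuda:0, and the previously current device is restored when the guard is destroyed. The sketch below shows the same pattern in isolation; the helper name and the scaling operation are hypothetical and not part of this commit.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical helper, for illustration only: do work on whichever GPU
// the input tensor lives on, rather than on the default cuda:0 device.
at::Tensor scale_on_input_device(const at::Tensor &x, double alpha) {
    // RAII guard: makes x's device the current CUDA device for this scope.
    at::cuda::CUDAGuard device_guard(x.device());
    // Allocations and kernel launches below now target x's device.
    auto out = at::empty_like(x);
    out.copy_(x);
    out.mul_(alpha);
    return out;
    // The guard's destructor restores the previously current device.
}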
tests/test_flash_attn.py
...
...
@@ -772,3 +772,78 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
    assert torch.equal(dq_unpad, dq_unpad_0)
    assert torch.equal(dk_unpad, dk_unpad_0)
    assert torch.equal(dv_unpad, dv_unpad_0)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='requires multiple GPUs')
def test_flash_attn_multigpu():
    seqlen = 256
    d = 64
    dropout_p = 0.0
    causal = False
    dtype = torch.float16
    device = 'cuda:1'
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
    )
    output_unpad, sm_lse, S_dmask = flash_attn_unpadded_qkvpacked_func(
        qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask_converted = convert_flash_attn_S_to_softmax(
        S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
    )
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                            causal=causal).item()

    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                   causal=causal)
    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    g = torch.randn_like(output)
    dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
    dqkv = dqkv_pad_fn(dqkv_unpad)
    dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
    dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
    print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
    print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
    print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
    print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
    print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)

    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
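Assuming a standard pytest setup for this repository (not specified in the commit itself), the new test can be selected by name with something like pytest tests/test_flash_attn.py -k multigpu; on machines with fewer than two visible GPUs the pytest.mark.skipif decorator skips it automatically.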