gaoqiong / flash-attention · Commits

Commit 4f81aff4, authored Oct 31, 2022 by Tri Dao
Add debug_barrier for all headdims in Triton bwd
Parent: bedcbd6a
Showing 2 changed files, with 6 additions and 4 deletions (+6 -4):

  flash_attn/flash_attn_triton.py   +2 -2
  tests/test_flash_attn.py          +4 -2
flash_attn/flash_attn_triton.py

@@ -300,7 +300,7 @@ def _bwd_kernel_one_col_block(
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
-        if not EVEN_HEADDIM:
-            tl.debug_barrier()
+        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+        tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
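Editorial note (not part of the commit): tl.debug_barrier() lowers to a barrier across all threads of a Triton program (similar to CUDA's __syncthreads()), so making the call unconditional forces a synchronization point between the dv update and the dp computation regardless of head dimension, which is the workaround for the suspected race described in the comments above. Below is a minimal sketch of the primitive, assuming a recent Triton release and a CUDA device; the kernel name, tensors, and BLOCK size are illustrative and not taken from flash-attention.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    # Illustrative kernel: load one tile, store it, then store a scaled copy
    # after a barrier.
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    tl.store(y_ptr + offs, x)
    # Program-wide synchronization point: every thread reaches this line
    # before any thread continues. The commit inserts the same call between
    # the dv and dp updates; here it only demonstrates placement, since this
    # toy kernel does not need it for correctness.
    tl.debug_barrier()
    tl.store(y_ptr + offs, x * 2.0)

x = torch.randn(128, device='cuda')
y = torch.empty_like(x)
_scale_kernel[(1,)](x, y, BLOCK=128)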
tests/test_flash_attn.py

@@ -896,11 +896,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
     print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
     print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
+    print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
     print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
     print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
     print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
-    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().mean().item()}')
-    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
+    print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
...
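Editorial note (not part of the commit): the two comment lines above describe the tolerance convention used by these tests: the Triton kernel's gradients are compared against an fp32 reference, and their error must be at most twice the error of a plain PyTorch implementation run in the same low precision. A minimal sketch of that pattern follows, with assumed names (assert_error_within_factor, out, out_pt, out_ref are illustrative, not from test_flash_attn.py).

import torch

def assert_error_within_factor(out, out_pt, out_ref, factor=2.0):
    # out:     output of the kernel under test (computed in fp16/bf16)
    # out_pt:  plain PyTorch implementation run in the same low precision
    # out_ref: plain PyTorch implementation run in fp32
    err = (out - out_ref).abs().max().item()
    err_pt = (out_pt - out_ref).abs().max().item()
    # The kernel's numerical error may be at most `factor` times the
    # numerical error of the low-precision PyTorch baseline.
    assert err <= factor * err_pt, f'error {err} exceeds {factor} x {err_pt}'

# Toy usage with stand-in tensors (both "implementations" identical here).
ref = torch.randn(8, 16)
pt = ref.half().float()
tri = ref.half().float()
assert_error_within_factor(tri, pt, ref)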