gaoqiong / flash-attention · Commits

Commit a5559a0e
Authored Jul 03, 2022 by Tri Dao
Parent: 6c3a8c65

Do P * dP (pointwise) in the bwd in fp32 instead of fp16

Showing 3 changed files with 22 additions and 43 deletions:
csrc/flash_attn/src/fmha.h                          +1   -1
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h    +19  -40
flash_attn/flash_attn_interface.py                  +2   -2
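Note on the change (an editor's sketch, not part of the commit): the backward pass forms the pointwise product P * dP, and this commit moves that multiply from fp16 to fp32 before the result is packed back to fp16. A minimal PyTorch sketch with illustrative shapes, showing only the precision difference between forming the product in fp32 and rounding once versus multiplying already-rounded fp16 operands:

import torch

# Illustrative only: P holds softmax probabilities, dP the gradient w.r.t. P.
torch.manual_seed(0)
P = torch.rand(128, 128, dtype=torch.float32)
dP = torch.randn(128, 128, dtype=torch.float32)

prod_fp32 = (P * dP).half()        # new path: product formed in fp32, rounded once at the end
prod_fp16 = P.half() * dP.half()   # old path: both operands rounded to fp16 before the multiply

max_err = (prod_fp32.float() - prod_fp16.float()).abs().max()
print(f"max difference between the two paths: {max_err.item():.3e}")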
csrc/flash_attn/src/fmha.h

@@ -96,7 +96,7 @@ struct FMHA_fprop_params : public Qkv_params {
     void * __restrict__ softmax_lse_ptr;

     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded;
+    int b, seqlen_q, seqlen_k, d;

     // The scaling factors for the kernel.
     float scale_bmm1f;
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -389,6 +389,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }

+    auto pointwise_mult = [](float p, float dp, float d) {
+        return p * ((!Is_dropout) || p >= 0.f ? dp : d);
+    };
+    #pragma unroll
+    for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
+        #pragma unroll
+        for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
+            softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]);
+            softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]);
+            softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]);
+        }
+    }
+
     // Load the fragments for K^T.
     typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
     smem_kt.load(frag_kt[0], 0);
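The added pointwise_mult lambda also folds in the dropout handling: as the comment in the removed fp16 path below notes, dropped elements are stored in frag_p as -p, so when dropout is enabled a negative p selects the per-row dp_sum value d instead of dp. A minimal Python sketch of that selection, with assumed tensor shapes (this is not the kernel code):

import torch

def pointwise_mult(p, dp, d, is_dropout: bool):
    # p:  (M, N) probabilities; entries dropped by dropout are assumed to be stored negated
    # dp: (M, N) gradient w.r.t. P
    # d:  (M, 1) per-row dp_sum, broadcast across each row
    if not is_dropout:
        return p * dp
    return p * torch.where(p >= 0, dp, d)

p = torch.tensor([[0.6, -0.4], [0.2, 0.8]])    # the negative entry marks a dropped element
dp = torch.randn(2, 2)
d = (p.abs() * dp).sum(dim=1, keepdim=True)    # stand-in for the per-row dp_sum
print(pointwise_mult(p, dp, d, is_dropout=True))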
@@ -404,46 +422,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }

-    softmax.unpack_noscale(acc_dp);
-    // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
-    // // will be zero.
-    // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-    Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
-    softmax.pack(frag_dp);
-    if (!Is_dropout) {
-        #pragma unroll
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
-            #pragma unroll
-            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
-                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
-            }
-        }
-    } else {
-        __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
-            dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
-        }
-        const __half zero_h = __half(0.f);
-        #pragma unroll
-        for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
-            #pragma unroll
-            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
-                #pragma unroll
-                for (int ii = 0; ii < 4; ++ii) {
-                    const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
-                    const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
-                    // If this element is dropped, then frag_p stores -p instead of p.
-                    // So pd holds -p * dp_sum in that case.
-                    const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
-                    const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
-                    const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
-                    frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
-                }
-            }
-        }
-    }
+    softmax.pack(frag_p);

     // Store dp to smem for transpose
     smem_dp.store(frag_p);
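For contrast, the branch removed above performed the same sign-based selection only after p, dp, and dp_sum had already been rounded to fp16 (per __half2 lane pair, using __float2half2_rn and __hmul2). A loose fp16 analogue in Python, again with assumed shapes and without the intrinsics:

import torch

def old_fp16_path(p, dp, d, is_dropout: bool):
    # Analogue of the removed branch: round the operands to fp16 first, then form
    # the product and the dropout-aware selection entirely in fp16.
    p16, dp16, d16 = p.half(), dp.half(), d.half()
    if not is_dropout:
        return p16 * dp16
    return p16 * torch.where(p16 >= 0, dp16, d16)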
flash_attn/flash_attn_interface.py

@@ -215,8 +215,8 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
-        k: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
-        v: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
+        k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
         cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
         cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
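A hedged usage sketch for the corrected docstring shapes. Only the parameters visible in the hunk header (q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) are confirmed by this diff; the trailing arguments max_seqlen_k and dropout_p are assumptions based on the surrounding docstring:

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

# Two variable-length sequences packed into one "unpadded" batch.
seqlens = torch.tensor([128, 96], dtype=torch.int32)
total, nheads, headdim = int(seqlens.sum()), 4, 64

q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)  # (total_k, nheads, headdim)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)

# cu_seqlens: (batch_size + 1,) int32 cumulative sequence lengths, used to index into q/k/v.
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = seqlens.cumsum(0).to(device="cuda", dtype=torch.int32)

max_seqlen = int(seqlens.max())
out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                               max_seqlen, max_seqlen, 0.0)  # dropout_p = 0.0 for evaluation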