gaoqiong / flash-attention

Commit 5ca83a9c authored Jul 22, 2024 by Tri Dao

Clean up softcapping bwd a bit

parent 751c762c
Showing 3 changed files with 7 additions and 18 deletions
README.md                                          +1 -1
csrc/flash_attn/src/flash_bwd_kernel.h             +5 -16
csrc/flash_attn/src/flash_bwd_launch_template.h    +1 -1
README.md

@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
 ### 2.6: Softcapping.
 Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.
 ## Performance
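Note on the feature itself: softcapping squashes each raw attention score with a scaled tanh before the softmax, and the backward-pass changes below differentiate through that cap. A minimal self-contained sketch of the Gemma-2-style formulation (plain C++; the helper names are illustrative, not from this repository):

    #include <cmath>

    // Squash a raw attention score into the open interval (-softcap, +softcap).
    float softcap_score(float score, float softcap) {
        return softcap * std::tanh(score / softcap);
    }

    // Derivative of the softcapped score w.r.t. the raw score; the backward
    // pass needs this factor when propagating gradients through the cap:
    // d/dx [c * tanh(x / c)] = 1 - tanh^2(x / c).
    float softcap_grad(float score, float softcap) {
        float t = std::tanh(score / softcap);
        return 1.0f - t * t;
    }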
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
+        Tensor dtanh = make_tensor_like(scores);
         if constexpr (Is_softcap) {
-            Tensor _dtanh = make_tensor_like(scores);
-            flash::calculate_dtanh(scores, _dtanh, params.softcap);
-            return _dtanh;
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
         }
-        else {
-            return nullptr;
-        }
-        }());
         // Alibi
         if (Has_alibi) {
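Note: the cleanup above replaces an immediately-invoked lambda, whose return type depended on Is_softcap (a tensor or nullptr), with a tensor that is always declared via make_tensor_like(scores) and only filled when softcapping is enabled. A self-contained sketch of the same pattern in plain C++17 (std::array stands in for the register-backed tensor; the dtanh formula here is the generic tanh derivative, not necessarily the exact expression inside flash::calculate_dtanh):

    #include <array>
    #include <cstdio>

    template <bool Is_softcap, std::size_t N>
    std::array<float, N> make_dtanh(const std::array<float, N>& tanh_scores) {
        std::array<float, N> dtanh{};   // always declared, like make_tensor_like(scores)
        if constexpr (Is_softcap) {
            for (std::size_t i = 0; i < N; ++i) {
                // d tanh(x) / dx = 1 - tanh^2(x), evaluated on already-squashed scores
                dtanh[i] = 1.0f - tanh_scores[i] * tanh_scores[i];
            }
        }
        return dtanh;   // left untouched (and optimized away) when Is_softcap is false
    }

    int main() {
        std::array<float, 3> t{0.0f, 0.5f, 0.9f};   // stand-ins for tanh-squashed scores
        auto d = make_dtanh<true>(t);
        std::printf("%.3f %.3f %.3f\n", d[0], d[1], d[2]);
        return 0;
    }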
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
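Note: the elementwise scaling above is the chain rule through the cap. With a softcapped score s = c * tanh(x / c), the gradient w.r.t. the raw score is dL/dx = dL/ds * (1 - tanh^2(x / c)); the dtanh tensor computed earlier holds that factor (up to whatever scaling flash::calculate_dtanh applies), and the loop multiplies it into dS element by element before dS feeds the rest of the backward pass.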
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             // If Is_local, set Is_causal to false
-            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
             // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
             if (smem_size_dq_dk_dv >= 48 * 1024) {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
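Note: the only change in this file folds the softcap flag into the dropout template parameter (Is_dropout && !Is_softcap), so the combination of dropout and softcapping is never instantiated, which cuts the number of compiled backward kernels. A self-contained sketch of the idea in plain C++ (this is not the repository's dispatch machinery):

    #include <cstdio>

    template <bool Is_dropout, bool Is_softcap>
    void kernel_stub() {
        std::printf("instantiated: Is_dropout=%d, Is_softcap=%d\n", Is_dropout, Is_softcap);
    }

    template <bool Is_dropout, bool Is_softcap>
    void launch() {
        // Same idea as `Is_dropout && !Is_softcap` in the diff: dropout is forced
        // off whenever softcapping is on, so the <true, true> kernel never exists.
        kernel_stub<Is_dropout && !Is_softcap, Is_softcap>();
    }

    int main() {
        launch<true, true>();    // runs kernel_stub<false, true>
        launch<true, false>();   // runs kernel_stub<true, false>
        return 0;
    }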