gaoqiong / flash-attention · Commits · d562aa63

Commit d562aa63 (unverified), authored Jul 31, 2024 by Woosuk Kwon and committed via GitHub on Jul 31, 2024.

    Sync with FA v2.6.0 to support soft capping (#13)

Parent: 12375706
Changes: 81 files in the commit; this page shows 20 changed files with 332 additions and 196 deletions (+332, -196).
Files on this page:

  csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu        +10   -0
  csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu        +10   -0
  csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_kernel.h                           +85   -22
  csrc/flash_attn/src/flash_fwd_launch_template.h                  +142  -156
  csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu  +7    -0
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
 }
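Note on the pattern above (repeated in all of the instantiation files below): every pre-existing, non-causal file gains an explicit `false` for a new boolean template argument, and a matching `*_causal_sm80.cu` file is added that pins it to `true`. The primary template's declaration is not part of this page; as a hedged sketch, it presumably now reads:

    // Sketch only - the actual declaration lives outside this page (flash.h /
    // the launch template). The third parameter is the new compile-time causal flag.
    template<typename T, int Headdim, bool Is_causal>
    void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

Splitting each (dtype, head dim, causal) combination into its own translation unit keeps per-file compile times down, which is what the auto-generated files' header comments refer to.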
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }

csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -4,7 +4,7 @@
 #pragma once
 
-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
@@ -22,9 +22,38 @@ namespace flash {
 
 using namespace cute;
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+    );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
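A note on the math: `apply_softcap` only computes `tanh(score * params.softcap)`. For that to implement the usual logit soft capping `cap * tanh(score * scale / cap)`, the host side presumably folds `scale / cap` into `params.softcap` and folds the remaining factor of `cap` into the softmax scale; that folding happens outside this page. A minimal host-side reference with the folding written out explicitly (variable names are illustrative, not the repo's; `std::tanh` stands in for `cutlass::fast_tanh`):

    #include <cmath>

    // Reference soft capping for one attention score:
    //   capped = cap * tanh((q.k) * softmax_scale / cap)
    // The device-side apply_softcap above performs only the tanh(x * c) step,
    // so the two constant factors are assumed to be pre-folded elsewhere.
    inline float softcap_reference(float qk, float softmax_scale, float cap) {
        const float pre  = softmax_scale / cap;  // assumed to be what params.softcap holds
        const float post = cap;                  // assumed to be folded into the softmax scale
        return post * std::tanh(qk * pre);
    }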
@@ -74,10 +103,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);

@@ -143,7 +170,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

@@ -300,6 +327,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16

@@ -363,6 +393,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();
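These two hunks (and their split-KV counterparts further down) fix the ordering: the cap is applied to the raw acc_s scores right after the Q·K^T GEMM, before masking and before the running softmax. A hedged, plain-C++ restatement of that per-row order of operations, useful for checking against the kernel (names are ours, not the kernel's):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // One query row: GEMM scores -> soft cap -> causal mask -> softmax.
    // `capped_scale` plays the role of params.softcap (assumed = scale / cap).
    // Assumes at least one visible column (last_visible_col >= 0).
    std::vector<float> softcapped_masked_softmax(std::vector<float> scores,
                                                 float capped_scale, float cap,
                                                 int last_visible_col) {
        for (int j = 0; j < (int)scores.size(); ++j) {
            scores[j] = cap * std::tanh(scores[j] * capped_scale);   // apply_softcap
            if (j > last_visible_col) scores[j] = -INFINITY;         // mask
        }
        float m = -INFINITY, s = 0.f;
        for (float x : scores) m = std::max(m, x);
        for (float &x : scores) { x = std::exp(x - m); s += x; }
        for (float &x : scores) x /= s;                              // softmax
        return scores;
    }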
@@ -425,10 +458,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);

@@ -471,7 +501,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, ...
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;

@@ -587,7 +617,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, ...
     Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
     auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);

@@ -881,6 +911,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, ...
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
+
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16

@@ -946,6 +980,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, ...
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();

@@ -1004,7 +1041,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, ...
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                          + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+        ) + m_block * kBlockM;
 
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},

@@ -1054,7 +1093,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, ...
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.

@@ -1070,12 +1109,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.

@@ -1084,7 +1123,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1110,21 +1149,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;
 
+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
 
-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
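The combine kernel has to write the reduced LSE back in whichever layout get_lse_tile read it in, which is what the gLSE_unpadded tensor and its composed layout handle. A hedged restatement of the three index mappings involved (derived from the shapes and strides in this diff; the helper names are ours):

    #include <cstdint>

    // Padded (b, h, seqlen_q) layout: lse[b][h][q]
    int64_t lse_index_padded(int64_t b, int64_t h, int64_t q, int64_t H, int64_t seqlen_q) {
        return (b * H + h) * seqlen_q + q;
    }
    // Unpadded varlen (h, total_q) layout; q_global already includes the batch's cu_seqlens offset.
    int64_t lse_index_unpadded(int64_t h, int64_t q_global, int64_t total_q) {
        return h * total_q + q_global;
    }
    // seqlenq_ngroups_swapped (h, seqlen_q, b) layout:
    int64_t lse_index_swapped(int64_t b, int64_t h, int64_t q, int64_t B, int64_t seqlen_q) {
        return (h * seqlen_q + q) * B + b;
    }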
@@ -1163,7 +1217,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {

csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -26,18 +26,18 @@
 template<typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // Enforce constraints
-        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
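SOFTCAP_SWITCH is used below but defined in static_switch.h, which is not part of this page. Assuming it mirrors the existing BOOL_SWITCH pattern (a runtime bool lowered to a constexpr template parameter, possibly with a compile-time opt-out), a sketch would be:

    // Hedged sketch only; the real macro lives in static_switch.h.
    #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)          \
        [&] {                                              \
            if (COND) {                                    \
                constexpr static bool CONST_NAME = true;   \
                return __VA_ARGS__();                      \
            } else {                                       \
                constexpr static bool CONST_NAME = false;  \
                return __VA_ARGS__();                      \
            }                                              \
        }()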
@@ -67,12 +67,13 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
         BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
             ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                 // Will only return softmax if dropout, to reduce compilation time.
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                 // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                 // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                 // If Is_local, set Is_causal to false
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                 // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
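Also worth noting in the kernel selection above: dropout and return_softmax are both forced off whenever soft capping is active (Is_dropout && !Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap), presumably to bound the number of template combinations that get compiled. Restated as plain boolean helpers (illustrative only):

    constexpr bool effective_dropout(bool is_dropout, bool is_softcap) {
        return is_dropout && !is_softcap;   // dropout path is dropped when softcap is on
    }
    constexpr bool effective_return_softmax(bool return_softmax, bool is_dropout, bool is_softcap) {
        return return_softmax && is_dropout && !is_softcap;
    }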
@@ -91,9 +92,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+                });
             });
         });
     });
 }
 
-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");

@@ -102,17 +104,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                                // If Is_local, set Is_causal to false
-                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
                                 // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                 // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                 if (smem_size >= 48 * 1024) {

@@ -155,31 +157,28 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     }
 }
 
-template<typename T, int Headdim>
+template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-                // Using block size (64 x 256) is 27% slower for seqlen=2k
+        if constexpr(!Is_dropout) {
+            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+            // Using block size (64 x 256) is 27% slower for seqlen=2k
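Since Is_causal is now a template parameter of run_mha_fwd_splitkv_dispatch and of every run_mha_fwd_hdim* function instead of a BOOL_SWITCH inside them, the caller has to branch on params.is_causal itself. That caller is not on this page; a hedged sketch of what the dispatch presumably looks like, shown for one head dimension and with plain branches instead of the repo's switch macros (params.is_bf16 is assumed to exist on Flash_fwd_params):

    // Sketch of the host-side dispatch after this change (assumed; the real
    // run_mha_fwd lives outside this diff).
    void run_mha_fwd_hdim128_dispatch_sketch(Flash_fwd_params &params, cudaStream_t stream) {
        if (params.is_bf16) {
            if (params.is_causal) { run_mha_fwd_<cutlass::bfloat16_t, 128, true >(params, stream); }
            else                  { run_mha_fwd_<cutlass::bfloat16_t, 128, false>(params, stream); }
        } else {
            if (params.is_causal) { run_mha_fwd_<cutlass::half_t, 128, true >(params, stream); }
            else                  { run_mha_fwd_<cutlass::half_t, 128, false>(params, stream); }
        }
    }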
@@ -194,16 +193,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         }
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            if (is_sm8x) {
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        if (is_sm8x) {
             if constexpr(!Is_causal) {

@@ -220,16 +217,14 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
+        if constexpr(!Is_dropout) {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.

@@ -257,16 +252,14 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
         }
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For A100, H100, 128 x 32 is the fastest.
+        // For A100, H100, 128 x 32 is the fastest.
         // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
         // and 128 x 64 with 8 warps is the fastest for non-causal.

@@ -287,14 +280,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
+        if constexpr(!Is_dropout) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {

@@ -306,10 +297,9 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     int device;

@@ -322,7 +312,6 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {

@@ -335,10 +324,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
         // is 8 elements. This means we can only use 128 threads and not 256 threads.
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;

@@ -353,7 +341,6 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         // For A100, we want to run with 128 x 64 (128KB smem).
         // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
         if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {

@@ -366,5 +353,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         // 96 KB
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
     });
 }

csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu  (new file, mode 100644)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);