gaoqiong / flash-attention · Commit d562aa63
"test/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f"
Unverified commit d562aa63, authored Jul 31, 2024 by Woosuk Kwon; committed by GitHub on Jul 31, 2024.
Sync with FA v2.6.0 to support soft capping (#13)
Parent: 12375706
Changes: 81 · Showing 20 changed files with 332 additions and 196 deletions (+332 / −196)
Changed files shown on this page:

csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu (+10 / −0)
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu (+2 / −2)
csrc/flash_attn/src/flash_fwd_kernel.h (+85 / −22)
csrc/flash_attn/src/flash_fwd_launch_template.h (+142 / −156)
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu (+7 / −0)
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
 }
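Note: the per-head-dimension .cu files exist only to hold explicit instantiations of the run_mha_fwd_ launcher; this commit adds a third, boolean (causal) non-type template argument, so each head dimension now gets a causal and a non-causal translation unit. The sketch below illustrates that split-per-instantiation pattern in isolation; the struct and function names here are hypothetical stand-ins, not the library's actual declarations.

#include <cstdio>

struct Params { int head_dim; bool is_causal; };

// Hypothetical launcher standing in for the library's templated run_mha_fwd_.
template <typename T, int kHeadDim, bool kIsCausal>
void run_fwd(const Params &p) {
    (void)p;
    std::printf("dtype bytes=%zu hdim=%d causal=%d\n", sizeof(T), kHeadDim, int(kIsCausal));
}

// Each auto-generated file would contain exactly one explicit instantiation like this,
// keeping per-file compile times small:
template void run_fwd<float, 224, false>(const Params &);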
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu (new file, mode 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -4,7 +4,7 @@
 #pragma once

-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
@@ -22,9 +22,38 @@ namespace flash {

 using namespace cute;

+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
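For reference, apply_softcap bounds each pre-softmax attention score with a tanh before masking and softmax. A minimal host-side sketch of the conventional soft-capping formulation is below; how the 1/sqrt(head_dim) scale and the cap are folded into params.softcap is handled elsewhere in the library and is not asserted here.

#include <cmath>
#include <vector>

// Conventional soft-capping: s -> cap * tanh(s / cap), which bounds every logit to (-cap, cap).
void softcap_reference(std::vector<float> &scores, float cap) {
    for (float &s : scores) {
        s = cap * std::tanh(s / cap);
    }
}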
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+        );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
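To make the stride choices in get_lse_tile concrete, the sketch below computes the equivalent flat element offset for each of the three LSE layouts. It is a host-side illustration only (the kernel builds cute layouts rather than raw offsets), and LseParams is a hypothetical subset of Flash_fwd_params.

#include <cstdint>

struct LseParams {                 // illustrative subset of Flash_fwd_params
    int64_t b, h, seqlen_q, total_q;
    bool unpadded_lse, seqlenq_ngroups_swapped;
};

// Flat offset of LSE[bidb][bidh][q] under the layout get_lse_tile would select,
// mirroring the make_stride(...) selection order above.
int64_t lse_flat_offset(const LseParams &p, int64_t bidb, int64_t bidh, int64_t q) {
    if (p.seqlenq_ngroups_swapped) {
        return bidh * p.seqlen_q * p.b + q * p.b + bidb;          // (h, seqlen_q, b)
    }
    if (p.unpadded_lse) {
        return bidh * p.total_q + /* binfo.q_offset(...) + */ q;  // varlen, packed over total_q
    }
    return bidb * p.h * p.seqlen_q + bidh * p.seqlen_q + q;       // (b, h, seqlen_q)
}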
@@ -74,10 +103,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -143,7 +170,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -300,6 +327,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }

         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -363,6 +393,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }

         flash::cp_async_wait<0>();
         __syncthreads();
@@ -425,10 +458,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -471,7 +501,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

     using Element = typename Kernel_traits::Element;
@@ -587,7 +617,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
     auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
@@ -881,6 +911,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }

         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -946,6 +980,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
        );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }

        flash::cp_async_wait<0>();
        __syncthreads();
@@ -1004,7 +1041,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                          + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q :
+            bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+        ) + m_block * kBlockM;

     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
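A small worked example of the two LSE-accumulator offset formulas above, with illustrative sizes chosen here (not taken from the diff); q_offset, which would come from cu_seqlens_q, is taken as 0 for simplicity.

#include <cstdio>

int main() {
    const long b = 2, h = 4, seqlen_q = 128, kBlockM = 64;
    const long n_split_idx = 1, bidb = 1, bidh = 3, m_block = 1;

    // Split (or padded-LSE) path: one seqlen_q-sized row per (split, batch, head).
    long padded = ((n_split_idx * b + bidb) * h + bidh) * seqlen_q + m_block * kBlockM;

    // Unpadded (varlen) path: rows packed by head over total_q tokens.
    const long total_q = 256, q_offset = 0;
    long unpadded = bidh * total_q + q_offset + m_block * kBlockM;

    std::printf("padded offset = %ld, unpadded offset = %ld\n", padded, unpadded);
}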
@@ -1054,7 +1093,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1070,12 +1109,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1084,7 +1123,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1110,21 +1149,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;

+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }

@@ -1163,7 +1217,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
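combine_attn_seqk_parallel merges the per-split partial softmax statistics; the lse_logsum line above is the log-sum-exp combination. A host-side sketch of the same reduction, for sanity checking only:

#include <cmath>
#include <cstdio>
#include <vector>

float combine_lse(const std::vector<float> &lse_per_split) {
    float lse_max = -INFINITY;
    for (float v : lse_per_split) lse_max = std::fmax(lse_max, v);
    float lse_sum = 0.f;
    for (float v : lse_per_split) lse_sum += std::exp(v - lse_max);
    // Guard mirrors the kernel: if everything was -inf (or NaN), return the +inf sentinel.
    return (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : std::log(lse_sum) + lse_max;
}

int main() {
    std::printf("combined lse = %f\n", combine_lse({1.0f, 2.0f, 0.5f}));
}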
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -26,18 +26,18 @@
     template<typename Kernel_traits, __VA_ARGS__> \
     __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // Enforce constraints
-        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -67,25 +67,27 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
         BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
             ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                     // Will only return softmax if dropout, to reduce compilation time.
                     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                     // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     // If Is_local, set Is_causal to false
-                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
                     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                     // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                     if (smem_size >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
                             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                     }
                     // int ctas_per_sm;
                     // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                     //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                     // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                     kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
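SOFTCAP_SWITCH, like the other *_SWITCH helpers used above, turns a runtime boolean into a compile-time constant so Is_softcap can be passed as a template argument. The generic sketch below shows that dispatch pattern; the real macros live in the library's static_switch.h, which is not part of this diff, so names and details here are illustrative only.

#include <cstdio>
#include <type_traits>

template <typename F>
void static_bool_switch(bool cond, F &&f) {
    // Call f with a compile-time true or false depending on the runtime condition.
    if (cond) {
        f(std::integral_constant<bool, true>{});
    } else {
        f(std::integral_constant<bool, false>{});
    }
}

template <bool IsSoftcap>
void launch_kernel() { std::printf("Is_softcap = %d\n", int(IsSoftcap)); }

int main() {
    float softcap = 30.0f;
    static_bool_switch(softcap > 0.0f, [](auto is_softcap) {
        launch_kernel<decltype(is_softcap)::value>();
    });
}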
@@ -93,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }

-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
@@ -102,17 +104,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                                // If Is_local, set Is_causal to false
-                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
                                 // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                 // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                 if (smem_size >= 48 * 1024) {
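The dim3 grid(...) line in this hunk chooses between two launch-grid packings depending on whether KV splitting is active. A small host-side sketch of that choice (illustrative only; GridDims is a hypothetical stand-in for dim3):

#include <cstdio>

struct GridDims { int x, y, z; };

GridDims splitkv_grid(int num_m_block, int num_splits, int b, int h) {
    // With num_splits > 1, blockIdx.y indexes the KV split and blockIdx.z indexes (batch, head);
    // otherwise blockIdx.y indexes the batch and blockIdx.z the head.
    if (num_splits > 1) return {num_m_block, num_splits, b * h};
    return {num_m_block, b, h};
}

int main() {
    GridDims g = splitkv_grid(/*num_m_block=*/8, /*num_splits=*/4, /*b=*/2, /*h=*/16);
    std::printf("grid = (%d, %d, %d)\n", g.x, g.y, g.z);
}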
...
@@ -155,161 +157,149 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -155,161 +157,149 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
}
}
}
}
template
<
typename
T
,
int
Headdim
>
template
<
typename
T
,
int
Headdim
,
bool
Is_causal
>
void
run_mha_fwd_splitkv_dispatch
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
void
run_mha_fwd_splitkv_dispatch
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
constexpr
static
int
kBlockM
=
64
;
// Fixed for all head dimensions
constexpr
static
int
kBlockM
=
64
;
// Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr
static
int
kBlockN
=
Headdim
<=
64
?
256
:
(
Headdim
<=
128
?
128
:
64
);
constexpr
static
int
kBlockN
=
Headdim
<=
64
?
256
:
(
Headdim
<=
128
?
128
:
64
);
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
kBlockM
,
kBlockN
,
4
,
false
,
false
,
T
>>
(
params
,
stream
);
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
kBlockM
,
kBlockN
,
4
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
}
}
template
<
typename
T
>
template
<
typename
T
,
bool
Is_causal
>
void
run_mha_fwd_hdim32
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
void
run_mha_fwd_hdim32
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
constexpr
static
int
Headdim
=
32
;
constexpr
static
int
Headdim
=
32
;
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
128
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
128
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
});
});
});
}
}
template
<
typename
T
>
template
<
typename
T
,
bool
Is_causal
>
void
run_mha_fwd_hdim64
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
void
run_mha_fwd_hdim64
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
constexpr
static
int
Headdim
=
64
;
constexpr
static
int
Headdim
=
64
;
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
constexpr
(
!
Is_dropout
)
{
if
constexpr
(
!
Is_dropout
)
{
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
128
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
128
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
}
else
{
}
else
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
}
});
});
});
}
}
template
<
typename
T
>
template
<
typename
T
,
bool
Is_causal
>
void
run_mha_fwd_hdim96
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
void
run_mha_fwd_hdim96
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
constexpr
static
int
Headdim
=
96
;
constexpr
static
int
Headdim
=
96
;
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
bool
is_sm8x
=
dprops
->
major
==
8
&&
dprops
->
minor
>
0
;
bool
is_sm8x
=
dprops
->
major
==
8
&&
dprops
->
minor
>
0
;
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
DROPOUT_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if
(
is_sm8x
)
{
if
(
is_sm8x
)
{
if
constexpr
(
!
Is_causal
)
{
if
constexpr
(
!
Is_causal
)
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
else
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
64
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
}
else
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
else
{
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
64
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
}
else
{
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
run_flash_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
128
,
64
,
4
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
// These two are always slower
}
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
});
// These two are always slower
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
});
});
}
}
template
<
typename
T
>
template
<
typename
T
,
bool
Is_causal
>
void
run_mha_fwd_hdim128
(
Flash_fwd_params
&
params
,
cudaStream_t
stream
)
{
void
run_mha_fwd_hdim128
(
Flash_fwd_params
&
params
,
cudaStream_t
stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-       BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        if constexpr (!Is_dropout) {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
            if (is_sm8x) {
                if constexpr (!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 1st ones are good for H100, A100
            // 2nd one is good for A6000 bc we get slightly better occupancy
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
        }
-       });
    });
}
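The hunk above drops the BOOL_SWITCH(params.is_causal, ...) layer because Is_causal is now a template parameter of the launch function itself, so causality is decided one level higher in the dispatcher. For readers unfamiliar with these switch macros, the sketch below shows the usual pattern behind BOOL_SWITCH / DROPOUT_SWITCH: an immediately-invoked lambda that turns a runtime bool into a constexpr one. It is a simplified illustration (names BOOL_SWITCH_SKETCH and dispatch_example are local to this sketch), not the exact definition in the repository's static_switch.h.

// Simplified sketch of the BOOL_SWITCH idea: a runtime condition selects one of
// two branches, inside each of which the flag is a compile-time constant that
// can be passed as a template argument.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)     \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Usage mirroring the launch code: p_dropout < 1.f becomes a compile-time flag.
void dispatch_example(float p_dropout) {
    BOOL_SWITCH_SKETCH(p_dropout < 1.f, Is_dropout, [&] {
        // Is_dropout is constexpr here, so it can be used like
        // run_flash_fwd<..., Is_dropout, Is_causal>(params, stream);
    });
}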
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-       BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        // For A100, H100, 128 x 32 is the fastest.
        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
        // and 128 x 64 with 8 warps is the fastest for non-causal.
        if (is_sm8x) {
            if constexpr (!Is_causal) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-       });
    });
}
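The hdim160 heuristic (like the hdim128 one) hinges on the is_sm8x check: sm80 (A100) reports compute capability 8.0, while sm86/sm89 parts report minor > 0 and have noticeably less shared memory per SM, which is why they get smaller or squarer tiles. The launch template reads this through at::cuda::getCurrentDeviceProperties(); the stand-alone sketch below shows the same check with the raw CUDA runtime so it can be built without PyTorch (everything in it is illustrative, not part of the repo).

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-alone version of the is_sm8x check used in the hunks above.
int main() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    // sm80 (A100): major == 8, minor == 0. sm86/sm89: minor > 0, with much less
    // shared memory per SM, hence the different block-size choices above.
    bool is_sm8x = prop.major == 8 && prop.minor > 0;
    std::printf("sm%d%d, is_sm8x = %d, smem/SM = %zu bytes\n",
                prop.major, prop.minor, (int)is_sm8x,
                prop.sharedMemPerMultiprocessor);
    return 0;
}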
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-       BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        if constexpr (!Is_dropout) {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-       });
    });
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    int device;
...
@@ -322,23 +312,21 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-       BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
        // If we have N = 32, there are only 1024 elements to load at once, where each load
        // is 8 elements. This means we can only use 128 threads and not 256 threads.
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-       });
    });
}
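The guard 2 * Headdim * (128 + 2 * 64) reads as a shared-memory estimate for the 128 x 64 tile in 16-bit precision: 2 bytes per element for a 128-row Q tile plus 64-row K and V tiles, all Headdim wide, which for Headdim = 224 is the 112 KB the comment mentions. The block below just spells out that arithmetic and the "only 128 threads" remark; the names and the Q/K/V interpretation are assumptions of this sketch, consistent with the comments but not taken verbatim from the header.

// Worked numbers behind the hdim224 heuristic above (compiles as a stand-alone
// translation unit; values are illustrative).
#include <cstddef>

constexpr std::size_t kBytesPerElt = 2;   // fp16 / bf16 (assumption of this sketch)
constexpr std::size_t kHeaddim     = 224;
constexpr std::size_t kBlockM      = 128; // rows of Q kept in smem
constexpr std::size_t kBlockN      = 64;  // rows of K and of V kept in smem

// Same expression as the guard: 2 * 224 * (128 + 2 * 64) = 114688 bytes = 112 KB.
constexpr std::size_t kSmemNeeded = kBytesPerElt * kHeaddim * (kBlockM + 2 * kBlockN);
static_assert(kSmemNeeded == 112 * 1024, "112 KB smem for the 128 x 64 tile");

// Why 128 x 32 with 8 warps is out: with kBlockKSmem = 32, an N = 32 chunk is
// 32 * 32 = 1024 elements, and each thread loads 8 elements per vectorized
// load, so only 1024 / 8 = 128 threads can participate, not 256.
static_assert(32 * 32 / 8 == 128, "only 128 threads, not 256");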
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
...
@@ -353,18 +341,16 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    }
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-       BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        // For A100, we want to run with 128 x 64 (128KB smem).
        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
        // 64 KB
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        // 96 KB
        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-       });
    });
}
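The two thresholds in the hdim256 branch encode the A100-versus-H100 split described in the comments. The block below evaluates them numerically; the per-GPU shared-memory capacities quoted in the comments are rough figures from NVIDIA's published limits, not values read from this repository.

// Worked numbers behind the hdim256 choice above (compiles as a stand-alone
// translation unit; illustrative only).
#include <cstddef>

constexpr std::size_t kHeaddim = 256;

// Guard 1: smem for one 128 x 64 CTA: 2 * 256 * (128 + 2 * 64) = 128 KB.
constexpr std::size_t kSmem128x64 = 2 * kHeaddim * (128 + 2 * 64);
static_assert(kSmem128x64 == 128 * 1024, "128 KB for the 128 x 64 config");

// Guard 2: smem for two 64 x 64 CTAs on one SM: 4 * 256 * (64 + 2 * 64) = 192 KB.
constexpr std::size_t kSmemTwo64x64 = 4 * kHeaddim * (64 + 2 * 64);
static_assert(kSmemTwo64x64 == 192 * 1024, "192 KB for two 64 x 64 CTAs");

// A100: roughly 164 KB smem per SM (163 KB usable per block), so the 128 KB
// config fits but two 96 KB CTAs do not -> first branch (128 x 64).
// H100: roughly 228 KB smem per SM, so max_smem_per_sm < 192 KB is false ->
// second branch (64 x 64, 96 KB each), and two CTAs per SM become resident.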
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu 0 → 100644 View file @ d562aa63
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
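This new file is one of the auto-generated translation units that each explicitly instantiate a single (dtype, head dim, is-causal) combination of the templated launcher, so the heavy template code compiles in parallel across files rather than in one giant unit. The single-file sketch below shows the explicit-instantiation mechanism with hypothetical names (run_dispatch_demo is not part of the repo); the real header declaring run_mha_fwd_splitkv_dispatch is flash_fwd_launch_template.h, included above.

// Minimal stand-alone illustration of the explicit-instantiation pattern the
// generated .cu files rely on.
#include <iostream>

template <typename T, int HeadDim, bool IsCausal>
void run_dispatch_demo(T x) {
    std::cout << "HeadDim=" << HeadDim << " causal=" << IsCausal
              << " value=" << x << '\n';
}

// Explicit instantiation definition: forces exactly this combination to be
// compiled into this translation unit, just as
//   template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(...);
// does for the real kernel in the file above.
template void run_dispatch_demo<float, 128, true>(float);

int main() { run_dispatch_demo<float, 128, true>(1.5f); }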