gaoqiong / flash-attention, commit d562aa63 (Unverified)
Authored Jul 31, 2024 by Woosuk Kwon; committed by GitHub on Jul 31, 2024

Sync with FA v2.6.0 to support soft capping (#13)
Parent: 12375706
Changes: 81
Showing 20 changed files with 332 additions and 196 deletions (+332, -196)
Changed files shown on this page:

  csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu        +10   -0
  csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu        +10   -0
  csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu               +2    -2
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu                +2    -2
  csrc/flash_attn/src/flash_fwd_kernel.h                           +85   -22
  csrc/flash_attn/src/flash_fwd_launch_template.h                  +142  -156
  csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu  +7    -0
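Background on the feature itself (not part of the diff): "soft capping" squashes each scaled attention score through a tanh so its magnitude can never exceed a fixed cap, the formulation used by Gemma-2-style models. A minimal host-side reference sketch of that math, with scale and softcap as illustrative parameters rather than values taken from this repository:

#include <cmath>
#include <cstddef>

// Reference only: cap each scaled score s = scale * dot(q, k) into the open
// interval (-softcap, softcap) via softcap * tanh(s / softcap).
void soft_cap_scores(float* scores, std::size_t n, float scale, float softcap) {
    for (std::size_t i = 0; i < n; ++i) {
        scores[i] = softcap * std::tanh(scores[i] * scale / softcap);
    }
}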
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -4,7 +4,7 @@
 #pragma once
 
-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
@@ -22,9 +22,38 @@ namespace flash {
 
 using namespace cute;
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+    // Otherwise, it's written as (h, b, seqlen_q).
+    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+    );
+
+    auto lse_layout = make_layout(lse_shape, lse_stride);
+    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+    return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
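A note on the new apply_softcap helper above: inside the kernel it only computes tanh(x * params.softcap) elementwise; the division by the cap and the re-multiplication by the cap appear to be folded into params.softcap and the softmax scale on the host side (the launch/API changes are collapsed on this page, so treat that folding as an assumption). A small standalone C++ sketch of the equivalence, with scale and cap as hypothetical values:

#include <cassert>
#include <cmath>

int main() {
    const float scale = 0.125f;   // hypothetical softmax scale (1/sqrt(headdim))
    const float cap   = 30.0f;    // hypothetical soft cap value
    const float s     = 7.0f;     // one raw q.k score

    // Full soft-capping formula: cap * tanh(scale * s / cap)
    const float reference = cap * std::tanh(scale * s / cap);

    // Kernel-style split (assumption): apply_softcap uses tanh(s * (scale / cap)),
    // and the later softmax scaling multiplies by cap instead of scale.
    const float folded_softcap = scale / cap;          // what params.softcap would hold under this assumption
    const float kernel_style   = cap * std::tanh(s * folded_softcap);

    assert(std::fabs(reference - kernel_style) < 1e-6f);
    return 0;
}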
@@ -74,10 +103,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)), make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
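The gLSE tile above now comes from get_lse_tile, which picks one of three LSE layouts depending on params.unpadded_lse and params.seqlenq_ngroups_swapped (see its comment). A hedged index-arithmetic sketch of those three cases; the function and parameter names here are illustrative, not the kernel's, and q_offset_b stands for the cumulative start of batch bidb in the packed variable-length layout:

#include <cstdint>

// Illustrative only: linear offset of LSE element (bidb, bidh, q) under the
// three layouts described in get_lse_tile's comment.
int64_t lse_linear_offset(bool unpadded_lse, bool seqlenq_ngroups_swapped,
                          int64_t b, int64_t h, int64_t seqlen_q, int64_t total_q,
                          int64_t bidb, int64_t bidh, int64_t q, int64_t q_offset_b) {
    if (!unpadded_lse) {
        return (bidb * h + bidh) * seqlen_q + q;   // (b, h, seqlen_q): fixed-length path
    }
    if (seqlenq_ngroups_swapped) {
        return (bidh * seqlen_q + q) * b + bidb;   // (h, seqlen_q, b): seqlen_q <-> h swap trick
    }
    return bidh * total_q + q_offset_b + q;        // (h, b, seqlen_q) packed over total_q tokens
}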
@@ -143,7 +170,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -300,6 +327,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -363,6 +393,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();
@@ -425,10 +458,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -471,7 +501,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;
@@ -587,7 +617,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
-    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
     auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
@@ -881,6 +911,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
+
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -946,6 +980,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();
@@ -1004,7 +1041,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                          + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+        ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+      ) + m_block * kBlockM;
 
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1054,7 +1093,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1070,12 +1109,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1084,7 +1123,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1110,21 +1149,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;
 
+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
 
-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
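For the layout composition added above (flat_layout / orig_layout / remapped_layout / final_layout): under the assumption that the flat LSE position is laid out in (b, h, seqlen_q) row-major order, the composed layout appears to remap it to the unpadded orderings named in the comment. A hedged sketch of that remapping in plain index arithmetic (names hypothetical, not the kernel's):

#include <cstdint>

// Illustrative re-indexing of what final_layout seems to compute. 'i' is the
// flat LSE position in (b, h, seqlen_q) row-major order.
int64_t remap_lse_index(int64_t i, int64_t b, int64_t h, int64_t seqlen_q,
                        bool seqlenq_ngroups_swapped) {
    const int64_t q  = i % seqlen_q;
    const int64_t h_ = (i / seqlen_q) % h;
    const int64_t b_ = i / (seqlen_q * h);
    return seqlenq_ngroups_swapped
        ? h_ * seqlen_q * b + q * b + b_          // {bidh, q_offset, bidb}
        : h_ * seqlen_q * b + b_ * seqlen_q + q;  // {bidh, bidb, q_offset}
}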
@@ -1163,7 +1217,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
csrc/flash_attn/src/flash_fwd_launch_template.h

(Diff collapsed on this page; +142 -156.)
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu  (new file, 0 → 100644)

+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
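These auto-generated .cu files exist so that each explicit template instantiation compiles in its own translation unit, which is what the "Splitting the different head dimensions to different files to speed up compilation" comment refers to. A minimal, self-contained illustration of the pattern (not code from this repository):

#include <cstdio>

// In the real code the template lives in a shared header; it is defined
// inline here only to keep the example self-contained.
template <typename T, int HeadDim, bool IsCausal>
void run_kernel(T x) {
    std::printf("headdim=%d causal=%d value=%f\n", HeadDim, int(IsCausal), double(x));
}

// Explicit instantiation: this translation unit emits exactly one variant,
// so many such files can be compiled in parallel.
template void run_kernel<float, 128, true>(float x);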