flash-attention commit a44f48df
Authored Oct 21, 2022 by Tri Dao
Parent commit: 1aa6d7d9

Split fwd on the seqlen_q dimension

Showing 5 changed files with 111 additions and 45 deletions (+111, -45)
    csrc/flash_attn/fmha_api.cpp                        +18  -8
    csrc/flash_attn/src/fmha.h                           +2  -0
    csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu  +42  -5
    csrc/flash_attn/src/fmha_fprop_kernel_1xN.h         +41  -30
    flash_attn/flash_attn_interface.py                   +8  -2
csrc/flash_attn/fmha_api.cpp

@@ -54,7 +54,8 @@ void set_params_fprop(FMHA_fprop_params &params,
                       void *softmax_lse_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      bool is_causal,
+                      int num_splits) {

     Data_type acc_type = DATA_TYPE_FP32;
     Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;

@@ -117,6 +118,7 @@ void set_params_fprop(FMHA_fprop_params &params,
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);

     params.is_causal = is_causal;
+    params.num_splits = num_splits;
 }

 void set_params_dgrad(FMHA_dgrad_params &params,

@@ -142,7 +144,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
                       void *dsoftmax_sum_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      bool is_causal,
+                      int num_splits) {

     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, h, d,

@@ -154,7 +157,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     is_causal,
+                     num_splits);

     // Set the pointers and strides.
     params.dq_ptr = dq.data_ptr();

@@ -186,6 +190,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         const bool zero_tensors,
         const bool is_causal,
         const bool return_softmax,
+        const int num_splits,
         c10::optional<at::Generator> gen_) {

     auto dprops = at::cuda::getCurrentDeviceProperties();

@@ -286,12 +291,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     is_causal,
+                     num_splits);

     run_fmha_fp16_sm80(launch_params, /*configure=*/true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
-    int64_t counter_offset = launch_params.elts_per_thread;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
     at::PhiloxCudaState rng_engine_inputs;

     if( is_dropout ) {

@@ -440,7 +447,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     is_causal,
+                     /*num_splits=*/1);

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

@@ -560,7 +568,8 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     is_causal,
+                     /*num_splits=*/1);

     launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());

     run_fmha_block_fp16_sm80(launch_params, /*configure=*/true);

@@ -706,7 +715,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     is_causal,
+                     /*num_splits=*/1);

     params.blockmask = static_cast<int *>(blockmask.data_ptr());

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
csrc/flash_attn/src/fmha.h

@@ -127,6 +127,8 @@ struct FMHA_fprop_params : public Qkv_params {
     bool is_bf16;

     bool is_causal;

+    int num_splits;  // How many SMs per attention matrix.
+
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -33,6 +33,32 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 95%
+// of the best efficiency.
+int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
+    float max_efficiency = 0.f;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
+        float eff = n_waves / ceil(n_waves);
+        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+        if (eff > max_efficiency) { max_efficiency = eff; }
+        efficiency.push_back(eff);
+    }
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
 __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
     fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);

@@ -75,7 +101,21 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
-    dim3 grid(launch_params.params.b, launch_params.params.h);
+
+    // Automatically set num_splits to maximize occupancy
+    if (launch_params.params.num_splits <= 0) {
+        int ctas_per_sm;
+        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+            &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
+        constexpr int M = Kernel_traits::Cta_tile_p::M;
+        launch_params.params.num_splits = num_splits_heuristic_fwd(
+            launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
+            ctas_per_sm, /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
+        );
+    }
+    dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());

@@ -103,10 +143,7 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
     if (launch_params.params.seqlen_k == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    } else if (launch_params.params.seqlen_k == 256) {
-        using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    } else {
+    } else if (launch_params.params.seqlen_k >= 256) {
         using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
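For reference, here is a minimal standalone sketch of the occupancy reasoning in the comment above (not the library code itself). It reproduces the batch * n_heads = 48, 108-SM example and assumes ctas_per_sm = 1, whereas the kernel queries that value with cudaOccupancyMaxActiveBlocksPerMultiprocessor.

    // Standalone sketch of the split heuristic: pick the smallest number of splits
    // whose wave efficiency is within 95% of the best achievable. Plain C++, no CUDA.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static int num_splits_sketch(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
        float max_efficiency = 0.f;
        std::vector<float> efficiency;
        efficiency.reserve(max_splits);
        for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
            // Waves of thread blocks needed by a (b, h, num_splits) grid; a partial
            // last wave leaves SMs idle, which is what "efficiency" measures.
            float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
            float eff = n_waves / std::ceil(n_waves);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
        for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
            if (efficiency[num_splits - 1] > 0.95f * max_efficiency) { return num_splits; }
        }
        return 1;
    }

    int main() {
        for (int s = 1; s <= 4; s++) {
            float n_waves = 48.f * s / 108.f;
            std::printf("splits = %d, efficiency = %.2f\n", s, n_waves / std::ceil(n_waves));
        }
        std::printf("chosen num_splits = %d\n", num_splits_sketch(48, 108, 1, 30));
        return 0;
    }

This prints efficiencies 0.44, 0.89, 0.67, 0.89 for 1 to 4 splits and chooses 2, matching the example in the comment. One note on the hunk above: the max_splits cap is written as std::min(30, (launch_params.params.seqlen_q + M - 1 / M)); if a ceiling division of seqlen_q by the M-row tile size is intended, that would normally read (seqlen_q + M - 1) / M.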
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;

@@ -266,15 +266,23 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // Wind gmem tiles to the correct position.
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int begin_og = begin;
-    begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
+    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // We want begin to be a multiple of gridDim.z
+    // This is because the row indices processed by each threadblock must align between the
+    // loop steps, otherwise we have a dependency between the blocks.
+    // For example, threadblock with blockIdx.z == 1 must process row indices that are
+    // k * gridDim.z + 1 for integer k.
+    const int begin_mod_z = begin % gridDim.z;
+    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
     const int steps_og = steps;
-    steps -= begin - begin_og;
-    gmem_q.move(begin);
-    gmem_o.move(begin);
-    gmem_o_tmp.move(begin);
-    if (Return_softmax) { gmem_s.move(begin); }
-    gmem_softmax_lse.move(begin);
+    steps -= begin;
+    gmem_q.move(begin + blockIdx.z);
+    gmem_o.move(begin + blockIdx.z);
+    gmem_o_tmp.move(begin + blockIdx.z);
+    if (Return_softmax) {
+        gmem_s.move(begin + blockIdx.z);
+    }
+    gmem_softmax_lse.move(begin + blockIdx.z);
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
     //     printf("begin = %d, steps = %d\n", begin, steps);
     // }

@@ -362,8 +370,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);

     // Load over the entire sequence length.
-    for (int l = 0; l < steps; l++) {
+    for (int l = blockIdx.z; l < steps; l += step_stride) {
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
+        //     printf("l = %d\n", l);
+        // }
         if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;

         // Declare the accumulators for the 1st gemm.
         fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];

@@ -380,9 +391,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         if (!Is_first) { gmem_o_tmp.load(out, 0); }

         // Trigger the load for the next Q values.
-        if (l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.smem_q.move_to_next_write_buffer();
-            gmem_q.move();
+            gmem_q.move(step_stride);
             gmem_q.load();
         }

@@ -395,27 +406,28 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Apply the mask.
         softmax.apply_mask(mask);

-        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
+        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
             // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
             __syncthreads();
         }
         // if (!Is_first) {
-        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
         //         printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
         //     }
         // }

         // Compute the max.
         float p_max[Mma_tile_p::MMAS_M * 2];
         if (!Is_first) {
-            smem_softmax_lse.store_pair(p_prev_lse, l % 2);
+            smem_softmax_lse.store_pair(p_prev_lse);
             // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
             for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
         }

         // Trigger the load for the next LSE values.
-        if (l < steps - 1) {
+        if (l + step_stride < steps) {
             if (!Is_first) {
-                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
+                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse), step_stride);
             }
         }

@@ -490,11 +502,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         softmax.template pack<elem_type>(frag_p);
         if (Return_softmax) {
             gmem_s.store(frag_p, mask);
-            gmem_s.move();
+            gmem_s.move(step_stride);
         }

         // Commit the values for Q into shared memory.
-        if (l < steps - 1) {
+        if (l + step_stride < steps) {
             gmem_q.commit(gemm_q_k.smem_q);
         }

@@ -548,7 +560,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         }
         float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
         if ((!Is_first) && o_rows_are_valid) {
-            smem_softmax_lse.load(p_prev_scale_o, rows, l % 2);
+            smem_softmax_lse.load(p_prev_scale_o, rows);
         }
         // if (!Is_first) {
         //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {

@@ -594,7 +606,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
                     reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
             }
         }
-        gmem_softmax_lse.move();
+        gmem_softmax_lse.move(step_stride);

         // Load from shared memory.
         if (!Is_first) {

@@ -627,22 +639,21 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Output the values.
         if (is_final_write) {
             gmem_o.template store<elem_type>(out, 0);
-            gmem_o.move();
+            gmem_o.move(step_stride);
         } else {
             gmem_o_tmp.store(out, 0);
         }

         // Move to the next part of the output.
-        if (!(Is_first && Is_last)) { gmem_o_tmp.move(); }
+        if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
         gemm_q_k.reload_k();

         // Make sure we are reading from the correct buffer.
         gemm_q_k.smem_q.move_to_next_read_buffer();

         // Trigger the load from shared memory for the next series of Q values.
-        if (l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.reload_q();
         }
     }  // Outer loop over the sequence length.
 }

@@ -672,14 +683,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
     }
 }
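To make the indexing in device_1xN_ concrete, here is a small host-side sketch (illustrative values only, not taken from the kernel traits) of how the seqlen_q tiles are distributed across blockIdx.z: each z-block starts at the first tile at or after the causal begin that falls in its residue class modulo gridDim.z, then strides by gridDim.z.

    // Host-side sketch of the seqlen_q split: which M-row tiles each blockIdx.z
    // handles, mirroring the begin rounding and the strided loop in device_1xN_.
    // total_steps, num_splits and causal_begin are made-up illustrative values.
    #include <cstdio>

    int main() {
        const int total_steps = 12;   // number of M-row tiles along seqlen_q
        const int num_splits = 3;     // plays the role of gridDim.z
        const int causal_begin = 4;   // first tile this loop step needs (causal case)

        for (int z = 0; z < num_splits; z++) {
            // Mirrors: begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z
            //                                            : begin + gridDim.z - begin_mod_z;
            int begin = causal_begin;
            const int begin_mod_z = begin % num_splits;
            begin = begin_mod_z <= z ? begin - begin_mod_z : begin + num_splits - begin_mod_z;
            const int steps = total_steps - begin;  // steps -= begin;

            std::printf("blockIdx.z = %d -> tiles:", z);
            // Mirrors: for (int l = blockIdx.z; l < steps; l += step_stride)
            for (int l = z; l < steps; l += num_splits) {
                std::printf(" %d", begin + l);
            }
            std::printf("\n");
        }
        return 0;
    }

With these values the three z-blocks cover tiles {6, 9}, {4, 7, 10} and {5, 8, 11}: every tile from the causal begin onward is processed exactly once, and each block stays on one residue class across loop steps. As the comment in the hunk notes, misaligned rows would create a dependency between blocks, since later loop steps read back the partial results (gmem_o_tmp, the stored softmax LSE) written for those rows.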
flash_attn/flash_attn_interface.py

@@ -14,10 +14,16 @@ def _get_block_size(device, head_dim, is_dropout):
 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                        dropout_p, softmax_scale, causal, return_softmax, generator=None):
+                        dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
+                        generator=None):
+    """
+    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
+    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
+    Don't change it unless you know what you're doing.
+    """
     softmax_lse, *rest = flash_attn_cuda.fwd(
         q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-        softmax_scale, False, causal, return_softmax, generator
+        softmax_scale, False, causal, return_softmax, num_splits, generator
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()