gaoqiong / flash-attention

Commit de19de7a — authored Jul 09, 2022 by Tri Dao
Implement for bf16
Parent: 6a77a6da
Showing 10 changed files with 329 additions and 262 deletions (+329 −262).
- README.md (+2 −2)
- csrc/flash_attn/fmha_api.cpp (+17 −13)
- csrc/flash_attn/src/fmha.h (+1 −0)
- csrc/flash_attn/src/fmha/kernel_traits.h (+5 −1)
- csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu (+94 −92)
- csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h (+29 −21)
- csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu (+122 −120)
- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h (+25 −12)
- csrc/flash_attn/src/fmha_utils.h (+9 −1)
- csrc/flash_attn/src/static_switch.h (+25 −0, new file)
README.md

```diff
@@ -31,12 +31,12 @@ Our tentative roadmap:
 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
-5. [Jun 2022] Support bf16.
+5. ~~[Jun 2022] Support bf16~~[Done].
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
 9. [Aug 2022] Fuse rotary embedding.
-10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
+10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).

 ## Speedup and Memory Savings
```
csrc/flash_attn/fmha_api.cpp

```diff
@@ -56,11 +56,13 @@ void set_params_fprop(FMHA_fprop_params &params,
                       bool is_causal) {

     Data_type acc_type = DATA_TYPE_FP32;
-    Data_type data_type = DATA_TYPE_FP16;
+    Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;

     // Reset the parameters
     memset(&params, 0, sizeof(params));

+    params.is_bf16 = q.dtype() == torch::kBFloat16;
+
     // Set the pointers and strides.
     params.q_ptr = q.data_ptr();
     params.k_ptr = k.data_ptr();
```
```diff
@@ -192,9 +194,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);

     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
```
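`mha_fwd` now enforces a relaxed dtype contract: fp16 is accepted on any supported GPU, bf16 only on SM8x devices, and all tensors must share q's dtype (`mha_bwd` below gets the same treatment). The repeated checks can be read as a helper like this — a hypothetical sketch for clarity, not code from the commit:

```cpp
#include <torch/extension.h>

// Hypothetical helper: fp16 everywhere, bf16 only on SM8x, and every
// tensor must match q's dtype.
static void check_qkv_dtypes(const at::Tensor &q, const at::Tensor &k,
                             const at::Tensor &v, bool is_sm8x) {
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
    TORCH_CHECK(k.dtype() == q_dtype);
    TORCH_CHECK(v.dtype() == q_dtype);
}
```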
```diff
@@ -326,14 +329,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();

-    TORCH_CHECK(q.dtype() == torch::kFloat16);
-    TORCH_CHECK(k.dtype() == torch::kFloat16);
-    TORCH_CHECK(v.dtype() == torch::kFloat16);
-    TORCH_CHECK(out.dtype() == torch::kFloat16);
-    TORCH_CHECK(dout.dtype() == torch::kFloat16);
-    TORCH_CHECK(dq.dtype() == torch::kFloat16);
-    TORCH_CHECK(dk.dtype() == torch::kFloat16);
-    TORCH_CHECK(dv.dtype() == torch::kFloat16);
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(k.dtype() == q_dtype);
+    TORCH_CHECK(v.dtype() == q_dtype);
+    TORCH_CHECK(out.dtype() == q_dtype);
+    TORCH_CHECK(dout.dtype() == q_dtype);
+    TORCH_CHECK(dq.dtype() == q_dtype);
+    TORCH_CHECK(dk.dtype() == q_dtype);
+    TORCH_CHECK(dv.dtype() == q_dtype);

     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
```
csrc/flash_attn/src/fmha.h

```diff
@@ -123,6 +123,7 @@ struct FMHA_fprop_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;

+    bool is_bf16;
     bool is_causal;
 };
```
csrc/flash_attn/src/fmha/kernel_traits.h

```diff
@@ -25,11 +25,13 @@
  ******************************************************************************/

+#include <cuda_fp16.h>
+
 #pragma once

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
+template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_ = __half>
 struct FMHA_kernel_traits {

     // The CTA description for the 1st GEMM.
@@ -80,6 +82,8 @@ struct FMHA_kernel_traits {
     // The shared memory tile to store dp sum.
     using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;

+    using elem_type = elem_type_;
+
     // Make sure the number of threads match.
     static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
```
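Because the new `elem_type_` parameter is defaulted to `__half`, every existing fp16 instantiation compiles unchanged, and bf16 callers opt in explicitly. A hypothetical usage sketch (assumes the repo's headers are on the include path):

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>
#include "fmha/kernel_traits.h"

// Existing call sites: the trailing parameter defaults to __half.
using Traits_fp16 = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
// New bf16 variant: same tile shape, different element type.
using Traits_bf16 = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, __nv_bfloat16>;

static_assert(std::is_same<Traits_fp16::elem_type, __half>::value, "default is fp16");
static_assert(std::is_same<Traits_bf16::elem_type, __nv_bfloat16>::value, "opt-in bf16");
```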
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

```diff
 /* Copyright (c) 2022, Tri Dao.
  */

+#include "static_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

@@ -22,24 +23,21 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

     constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
-    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
-    // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);

     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
-    bool is_causal = params.is_causal;
-    auto kernel = is_dropout
-        ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
-        : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
+            auto kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, IsCausalConst>;
             if (params.seqlen_k == blocksize_c) {
-                kernel = is_dropout
-                    ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
-                    : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/1>;
             } else if (params.seqlen_k == blocksize_c * 2) {
-                kernel = is_dropout
-                    ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
-                    : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
+                kernel = &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, IsCausalConst, /*loop_steps=*/2>;
             }
+            // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
             if (smem_size_dq_dk_dv >= 48 * 1024) {
                 FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
             }
@@ -47,63 +45,66 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
             dim3 grid(params.b, params.h);
             kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
             FMHA_CHECK_CUDA(cudaPeekAtLastError());
+        });
+    });
 }
```
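One detail worth noting in the launch path: dynamic shared memory beyond the default 48 KB must be requested per kernel with `cudaFuncSetAttribute` before launch, which is why the attribute call now sits inside the `BOOL_SWITCH` lambdas, next to the concrete kernel pointer it applies to. A generic sketch of the pattern (illustrative; `my_kernel` is a placeholder, not a function from the repo):

```cuda
#include <cuda_runtime.h>

__global__ void my_kernel() { extern __shared__ char smem[]; (void)smem; }

void launch(dim3 grid, int threads, int smem_size, cudaStream_t stream) {
    if (smem_size >= 48 * 1024) {
        // Opt in to >48 KB of dynamic shared memory for this specific kernel.
        cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    my_kernel<<<grid, threads, smem_size, stream>>>();
}
```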
The dispatch function is now wrapped in a `BOOL_SWITCH` on `params.is_bf16`, and every `FMHA_kernel_traits<...>` instantiation gains a trailing `elem_type` argument. New version:

```cpp
void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (params.d == 16) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            } else if (params.seqlen_k == 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            } else {
                // TD [2022-05-15] 512 gives wrong results rn
                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            }
        } else if (params.d == 32) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            } else if (params.seqlen_k >= 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            }
        } else if (params.d == 64) {
            if (params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
            } else if (params.seqlen_k >= 256) {
                auto dprops = at::cuda::getCurrentDeviceProperties();
                if (dprops->major == 8 && dprops->minor == 0) {
                    // Don't share smem for K & V, and don't keep V in registers.
                    // This speeds things up by 2-3% by avoiding register spills, but it
                    // uses more shared memory, which is fine on A100 but not other GPUs.
                    // For other GPUs, we keep V in registers.
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
                } else if (dprops->major == 8 && dprops->minor > 0) {
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
                } else if (dprops->major == 7 && dprops->minor == 5) {
                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
                }
            }
        } else if (params.d == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        }
        // (A long block of commented-out legacy dispatch code follows in the file;
        //  hunk @@ -111,17 +112,18 @@ appends ", elem_type" to each commented
        //  FMHA_kernel_traits instantiation as well. Omitted here for brevity.)
    });
}
```

(No newline at end of file.)
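The `std::conditional` line above is the crux of the bf16 plumbing: a single compile-time alias selects the element type for every kernel-traits instantiation in the function. A standalone sketch of the pattern (illustrative only; assumes a CUDA toolchain for the fp16/bf16 headers):

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

template <bool IsBf16>
void instantiate_traits() {
    // Same selection as in run_fmha_dgrad_fp16_sm80: one alias, two possible
    // 16-bit element types, resolved entirely at compile time.
    using elem_type = typename std::conditional<IsBf16, __nv_bfloat16, __half>::type;
    static_assert(sizeof(elem_type) == 2, "both candidates are 16-bit types");
    // The real code would now write, e.g.:
    //   using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
}

int main() {
    instantiate_traits<false>();  // __half path
    instantiate_traits<true>();   // __nv_bfloat16 path
    return 0;
}
```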
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

```diff
@@ -35,6 +35,14 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first,
 inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph, const int loop_step_idx) {

+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    using elem_type = typename Kernel_traits::elem_type;
+#else
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
+    assert(is_fp16_type);
+    using elem_type = __half;
+#endif
+
     // The description of the CTA tile for the 1st batched GEMM.
     using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
     // The description of the CTA tile for the 2nd batched GEMM.
@@ -106,7 +114,7 @@
     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

     // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
-    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
+    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false, elem_type>;

     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
@@ -214,7 +222,7 @@
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
             gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx);
     }
@@ -333,7 +341,7 @@
     Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
     static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
     static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
-    softmax.template pack<__half>(frag_p);
+    softmax.template pack<elem_type>(frag_p);

     // Store s * dmask to smem for transpose
     smem_s.store(frag_p);
```

The remaining hunks (@@ -369,9 +377,9 @@, @@ -385,9 +393,9 @@, @@ -424,7 +432,7 @@, @@ -442,14 +450,14 @@, @@ -475,7 +483,7 @@, @@ -485,13 +493,13 @@, @@ -519,7 +527,7 @@, @@ -542,13 +550,13 @@, @@ -575,7 +583,7 @@, @@ -629,11 +637,11 @@) apply the same mechanical substitution, replacing `__half` with `elem_type` at every `fmha::gemm_cl<...>`, `softmax.template pack<...>`, `hrelu_<...>`, `gmem_dq.template store<...>`, and `smem_dv`/`smem_dk` `.template store<...>` call site, including the commented-out variants.
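A note on the arch guard at the top of `compute_dq_dk_dv_1xN_one_iter`: this header is compiled once per target architecture, and the bf16 instantiation still has to *compile* for pre-SM80 targets even though it is never launched there (the API layer only admits bf16 on SM8x). That is presumably why the check is a runtime `assert` inside the dead branch rather than a `static_assert`, which would fail the whole compilation. Annotated excerpt (comments are mine, not from the commit):

```cuda
// Compiled per-arch: on SM80+ the traits' element type (fp16 or bf16) is used
// directly; on older targets the bf16 path is pinned back to __half so the
// device code still compiles, and the assert documents that it must never run.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using elem_type = typename Kernel_traits::elem_type;
#else
    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
    assert(is_fp16_type);
    using elem_type = __half;
#endif
```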
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

```diff
@@ -25,6 +25,10 @@
  ******************************************************************************/

+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "static_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

@@ -36,26 +40,8 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
 template<typename Kernel_traits>
 void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params, const bool configure) {
-    bool is_causal = launch_params.params.is_causal;
-    // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
-    auto kernel = launch_params.is_dropout
-        ? (is_causal
-            ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
-            : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
-        : (is_causal
-            ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
-            : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));

     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
-    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
-    // Don't need smem_size_softmax_lse if we're not looping
-    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
-        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
-    if (smem_size >= 48 * 1024) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    }

     if (configure) {
         using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
@@ -68,117 +54,133 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
         return;
     }

+    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
+    // Don't need smem_size_softmax_lse if we're not looping
+    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
+
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
+        BOOL_SWITCH(launch_params.params.is_causal, IsCausalConst, [&] {
+            BOOL_SWITCH(launch_params.return_softmax, ReturnSoftmaxConst, [&] {
+                auto kernel = &fmha_fprop_fp16_sm80_loop_kernel<
+                    Kernel_traits, IsDropoutConst, IsCausalConst, ReturnSoftmaxConst>;
+                if (smem_size >= 48 * 1024) {
+                    FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                }
+                dim3 grid(launch_params.params.b, launch_params.params.h);
+                kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params);
+                FMHA_CHECK_CUDA(cudaPeekAtLastError());
+            });
+        });
+    });
 }
```

The dispatch function gets the same treatment as the backward pass: a `BOOL_SWITCH` on `params.is_bf16` selects `elem_type`, and every traits instantiation gains the trailing argument. New version:

```cpp
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure) {
    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (launch_params.params.d == 16) {
            if (launch_params.params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if (launch_params.params.seqlen_k == 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                // TD [2022-05-15] 512 gives wrong results rn
                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        } else if (launch_params.params.d == 32) {
            if (launch_params.params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if (launch_params.params.seqlen_k == 256) {
                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        } else if (launch_params.params.d == 64) {
            if (launch_params.params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if (launch_params.params.seqlen_k >= 256) {
                auto dprops = at::cuda::getCurrentDeviceProperties();
                if (dprops->major == 8 && dprops->minor >= 0) {
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                } else if (dprops->major == 7 && dprops->minor == 5) {
                    if (launch_params.is_dropout) {  // Need to use the same block size as backward
                        using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
                        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                    } else {
                        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
                        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                    }
                }
            }
        } else if (launch_params.params.d == 128) {
            if (launch_params.params.seqlen_k == 128) {
                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                auto dprops = at::cuda::getCurrentDeviceProperties();
                if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
                    // TD [2022-06-05] Keep K in registers to reduce register spilling
                    // Gives about 6% speedup compared to using block size 128.
                    using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                } else {  // Need to use the same block size as backward
                    using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                }
            }
        }
        // (Large blocks of commented-out legacy dispatch code follow in the file;
        //  the commit appends ", elem_type" to each commented FMHA_kernel_traits
        //  instantiation as well. Omitted here for brevity.)
    });
}
```

(No newline at end of file.)
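The `loop_steps` computation above is a ceiling division: with `blocksize_c = 256`, a sequence length of 512 gives 2 loop steps, while 513 gives 3. Spelled out as a self-contained one-liner (illustrative, not repo code):

```cpp
// Ceiling division: number of key blocks a CTA must iterate over.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
static_assert(ceil_div(512, 256) == 2, "exact multiple");
static_assert(ceil_div(513, 256) == 3, "partial last block still counts");
```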
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

```diff
@@ -72,7 +72,7 @@ struct Gemm_Q_K_base {
     Smem_tile_k smem_k;
 };

-template<typename Kernel_traits, bool K_in_regs>
+template<typename Kernel_traits, bool K_in_regs, typename elem_type_ = __half>
 struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

     using Base = Gemm_Q_K_base<Kernel_traits>;
@@ -81,6 +81,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
     using Smem_tile_k = typename Base::Smem_tile_k;
     using Fragment_k = typename Base::Fragment_k;
     using Mma_tile_p = typename Base::Mma_tile_p;
+    using elem_type = elem_type_;

     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
     // If V is stored in shared memory, we can't load K using the same shared memory.
@@ -132,8 +133,8 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
 };

-template<typename Kernel_traits>
-struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
+template<typename Kernel_traits, typename elem_type_>
+struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {
     using Base = Gemm_Q_K_base<Kernel_traits>;
     using Smem_tile_o = typename Base::Smem_tile_o;
     using Smem_tile_q = typename Base::Smem_tile_q;
@@ -141,6 +142,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
     using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
     using Fragment_k = typename Base::Fragment_k;
     using Mma_tile_p = typename Base::Mma_tile_p;
+    using elem_type = elem_type_;

     Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
@@ -197,6 +199,13 @@ constexpr size_t get_dynamic_smem_size(){
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
 inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    using elem_type = typename Kernel_traits::elem_type;
+#else
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
+    assert(is_fp16_type);
+    using elem_type = __half;
+#endif

     // The description of the CTA tile for the 1st batched GEMM.
     using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
@@ -231,7 +240,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;

-    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
+    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;

     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
```

The remaining hunks (@@ -115,12 +116,12 @@, @@ -175,12 +177,12 @@, @@ -466,7 +479,7 @@, @@ -482,7 +495,7 @@, @@ -494,7 +507,7 @@, @@ -605,7 +618,7 @@) mechanically replace `<__half>` with `<elem_type>` at the `fmha::gemm_cl`, `softmax.template pack`, `hrelu_`, and `gmem_o.template store` call sites; @@ -363,6 +372,10 @@ only adds a few commented-out debug printf lines.
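The `Gemm_Q_K` change is a standard piece of template plumbing: the primary template grows a defaulted type parameter and the partial specialization is extended to forward it, so the `elem_type` alias is available on both paths. In miniature (hypothetical sketch, stripped of the FMHA details):

```cpp
#include <cuda_fp16.h>

// Primary template: K kept in registers; element type defaults to __half.
template <typename Traits, bool K_in_regs, typename elem_type_ = __half>
struct Gemm_Q_K_mini {
    using elem_type = elem_type_;
};

// Partial specialization for K_in_regs == false must also name the type
// parameter so the alias stays available when K is reloaded from smem.
template <typename Traits, typename elem_type_>
struct Gemm_Q_K_mini<Traits, false, elem_type_> {
    using elem_type = elem_type_;
};
```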
csrc/flash_attn/src/fmha_utils.h

```diff
@@ -32,6 +32,7 @@
 #include <stdlib.h>
 #include <cuda_runtime_api.h>
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -50,7 +51,7 @@
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
+enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -60,6 +61,11 @@ static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
         uint16_t h = reinterpret_cast<const uint16_t &>( x );
         ushort2 h2 = { h, h };
         alpha = reinterpret_cast<const uint32_t &>( h2 );
+    } else if( dtype == DATA_TYPE_BF16 ) {
+        __nv_bfloat16 x = __float2bfloat16( norm );
+        uint16_t h = reinterpret_cast<const uint16_t &>( x );
+        ushort2 h2 = { h, h };
+        alpha = reinterpret_cast<const uint32_t &>( h2 );
     } else if( dtype == DATA_TYPE_FP32 ) {
         alpha = reinterpret_cast<const uint32_t &>( norm );
     } else if( dtype == DATA_TYPE_INT32 ) {
@@ -78,6 +84,8 @@ static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
         return n * 4;
     case DATA_TYPE_FP16:
         return n * 2;
+    case DATA_TYPE_BF16:
+        return n * 2;
     case DATA_TYPE_INT32:
         return n * 4;
     case DATA_TYPE_INT8:
```
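The new `DATA_TYPE_BF16` branch in `set_alpha` mirrors the fp16 one: convert the float scale to bfloat16, then replicate it into both 16-bit halves of a 32-bit word, ready for packed two-way math. A minimal sketch of just that conversion (hypothetical standalone helper, not the repo's API):

```cpp
#include <cuda_bf16.h>
#include <cstdint>
#include <cstring>

// Pack one float, converted to bfloat16, into both halves of a uint32_t.
static inline uint32_t pack_bf16x2(float norm) {
    __nv_bfloat16 x = __float2bfloat16(norm);
    uint16_t h;
    std::memcpy(&h, &x, sizeof(h));            // same bits the source gets via reinterpret_cast
    return (uint32_t(h) << 16) | uint32_t(h);  // duplicate into high and low halves
}
```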
csrc/flash_attn/src/static_switch.h (new file)

```cpp
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                       \
        if (COND) {                             \
            constexpr bool CONST_NAME = true;   \
            return __VA_ARGS__();               \
        } else {                                \
            constexpr bool CONST_NAME = false;  \
            return __VA_ARGS__();               \
        }                                       \
    }()
```
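The macro turns a runtime boolean into a `constexpr` one by compiling both branches and selecting between them at runtime. A minimal standalone demo of how it is meant to be used (hypothetical, not part of the commit; assumes the header is on the include path and C++17 for `if constexpr`):

```cpp
#include <cstdio>
#include "static_switch.h"

// A templated function standing in for a real CUDA kernel launch: the
// boolean is a template parameter, so the branch costs nothing at runtime.
template <bool IsDropout>
void run_kernel() {
    if constexpr (IsDropout) { std::printf("dropout variant\n"); }
    else                     { std::printf("no-dropout variant\n"); }
}

int main() {
    bool is_dropout = true;  // runtime value, e.g. params.p_dropout < 1.f
    // BOOL_SWITCH instantiates run_kernel<true> and run_kernel<false>,
    // then picks one based on the runtime flag.
    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { run_kernel<IsDropoutConst>(); });
    return 0;
}
```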