gaoqiong / flash-attention · Commits

Commit 1fcbe6f0, authored May 20, 2022 by Tri Dao
First release

Showing 13 changed files, with 3496 additions and 0 deletions (+3496 / -0).
Files changed:

  csrc/stream_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu         +93   -0
  csrc/stream_attn/src/fmha_dgrad_kernel_1xN_loop.h                 +734  -0
  csrc/stream_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h     +855  -0
  csrc/stream_attn/src/fmha_fprop_fp16_kernel.sm80.cu               +127  -0
  csrc/stream_attn/src/fmha_fprop_kernel_1xN.h                      +646  -0
  csrc/stream_attn/src/fmha_kernel.h                                +179  -0
  csrc/stream_attn/src/fmha_utils.h                                 +92   -0
  csrc/stream_attn/src/philox.cuh                                   +144  -0
  rotary.py                                                         +135  -0
  stream_attn_interface.py                                          +100  -0
  stream_blocksparse_attn_interface.py                              +142  -0
  streaming_attention.py                                            +114  -0
  streaming_blocksparse_attention.py                                +135  -0
csrc/stream_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu (new file, mode 100644)
/* Copyright (c) 2022, Tri Dao.
 */

#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
    static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
    static_assert(smem_size_dp_sum == 16 * 4 * 2);

    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2)
        + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;

    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
    bool is_causal = params.is_causal;
    auto kernel = is_dropout
        ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true>
                     : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
        : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true>
                     : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
    constexpr int N = Kernel_traits::Cta_tile_p::N;
    if (params.s == N) {
        kernel = is_dropout
            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1>
                         : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1>
                         : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
    } else if (params.s == N * 2) {
        kernel = is_dropout
            ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2>
                         : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
            : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2>
                         : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
    }
    if (smem_size_dq_dk_dv >= 48 * 1024) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    }
    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
    if (params.d == 16) {
        if (params.s == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        } else if (params.s == 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        } else {
            // TD [2022-05-15] 512 gives wrong results rn
            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        }
    } else if (params.d == 32) {
        if (params.s == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        } else if (params.s >= 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        }
    } else if (params.d == 64) {
        if (params.s == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        } else if (params.s >= 256) {
            // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
            // Don't share smem for K & V, and don't keep V in registers
            // This speeds things up by 2-3% by avoiding register spills, but it
            // uses more shared memory, which is fine on A100 but not other GPUs.
            // For other GPUs, we should either use N=128 as the base, or keep V in registers.
            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
        }
    } else if (params.d == 128) {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
        run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    }
}
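
[Editor's note] The launcher above opts the kernel into more than the default 48 KB of dynamic shared memory before launching it. Below is a minimal standalone sketch of that same pattern; the kernel body, the 96 KB size, and the names are illustrative only and are not part of the commit (96 KB fits the A100 opt-in limit, other GPUs may cap lower).

#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_kernel(float *out) {
    extern __shared__ char smem[];                 // dynamic shared memory, sized at launch time
    smem[threadIdx.x] = 0;
    if (threadIdx.x == 0) { out[blockIdx.x] = float(blockDim.x); }
}

int main() {
    const int smem_bytes = 96 * 1024;              // more than the 48 KB default, as in the loop kernels
    // Kernels that ask for > 48 KB of dynamic shared memory must opt in explicitly.
    cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    float *out;
    cudaMalloc(&out, sizeof(float));
    demo_kernel<<<dim3(1, 1), 128, smem_bytes, 0>>>(out);   // grid, threads, smem, stream
    printf("launch status: %s\n", cudaGetErrorString(cudaPeekAtLastError()));
    cudaFree(out);
    return 0;
}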
csrc/stream_attn/src/fmha_dgrad_kernel_1xN_loop.h (new file, mode 100644)
/* Copyright (c) 2022, Tri Dao.
 */

#pragma once

#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Smem_dp_sum, int M>
inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
                                Smem_dp_sum smem, const int buffer_idx) {
    #pragma unroll
    for (int mi = 0; mi < M; ++mi) {
        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
    }
    static_assert(M == 1);
    smem.store(sum[0], buffer_idx);
    // smem.store(sum, buffer_idx);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
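
// [Editor's sketch, not part of the original commit] A scalar reference for the quantity
// dot_do_o accumulates: for each row i of the current tile, dp_sum[i] = sum_d dO[i][d] * O[i][d].
// The device helper above performs the same reduction on packed fp16x8 (uint4) data with
// fmha::hmulsum8 plus a warp-level reduce; the names rows and head_dim here are illustrative only.
static inline void dot_do_o_reference(float *dp_sum, const float *dO, const float *O,
                                      int rows, int head_dim) {
    for (int i = 0; i < rows; ++i) {
        float acc = 0.f;
        for (int d = 0; d < head_dim; ++d) {
            acc += dO[i * head_dim + d] * O[i * head_dim + d];  // elementwise product, reduced over the head dimension
        }
        dp_sum[i] = acc;
    }
}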

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph, const int loop_step_idx) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
    // The description of the CTA tile for the 3rd batched GEMM.
    using Cta_tile_dkv = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;

    static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
    static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
    static_assert(Cta_tile_dkv::K == 16);

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
    // The MMA tile for the 3rd GEMM.
    using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to reload Q transposed.
    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K^T. Treat K^T as V
    using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;

    // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
    // The shared memory tile to load dO.
    // Treating dO as Q.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
    // The shared memory tile to reload dO transposed.
    using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;

    // The global memory tile to load O. Loading O here is similar to loading dO.
    using Gmem_tile_o = Gmem_tile_do;

    // The global memory tile to store dQ.
    // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
    using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
    using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
    // The shared memory tile to swizzle dQ.
    using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store dV.
    using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle dV.
    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;

    // The global memory tile to store dK.
    using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle dK.
    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
    static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
    using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];
    // Shared memory layout if we keep V in registers:
    // dO | Q | K / V | dQ | S | dP | dP_sum
    // dV | dK
    // Shared memory layout if we keep V shared memory:
    // dO | Q | K | V | dQ | S | dP | dP_sum
    // dV | dK

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    // if( binfo.stop_early() ) return;
    if (binfo.stop_early(loop_step_idx * Cta_tile_p::N)) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for dQ.
    Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);
    // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
    Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for dO.
    Smem_tile_do smem_do(&smem_[0], tidx);
    Smem_tile_dot smem_dot(&smem_[0], tidx);
    // Allocate the shared memory tile loader for Q^T.
    // TODO: assert that this points to the same memory as gemm_q_k.smem_q
    Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);

    Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
    Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
    const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
    // constexpr int steps = Cta_tile_p::N / Cta_tile_p::M;
    const int steps = params.s / Cta_tile_p::M - begin;

    // Wind gmem tiles to the correct position.
    gmem_q.move(begin);
    gmem_do.move(begin);
    gmem_o.move(begin);
    gmem_dq.move(begin);
    gmem_dq_tmp.move(begin);
    // TODO: need to move gmem_s if we want the intermediate result for debugging
    gmem_softmax_lse.move(begin);
    gmem_softmax_d.move(begin);

    if (!Is_first) {
        gmem_k.move(loop_step_idx);
        gmem_v.move(loop_step_idx);
    }

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for O.
    if (Is_first) { gmem_o.load(); }

    float p_lse[Mma_tile_p::MMAS_M * 2];
    gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
    gmem_softmax_lse.move();

    float dp_sum[Mma_tile_p::MMAS_M * 2];
    if (!Is_first) {
        gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
        gmem_softmax_d.move();
    }

    float dp_sum_regs[Gmem_tile_do::LDGS];
    Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);

    if (!Is_first) { __syncthreads(); }
    // Commit the data for Q, dO, and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_do.commit(smem_do);
    if (Is_first) {
        dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
        const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
        if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
            gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
        }
        gmem_softmax_d.move();
    }

    // Instead of scaling dP by rp_dropout, we scale V instead
    if (Is_dropout) {
        const uint32_t scale_dropout = params.scale_dropout;
        #pragma unroll
        for(int it=0; it < Gmem_tile_v::LDGS; it++){
            gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
        }
    }

    gmem_v.commit(smem_v);

    // const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    // #pragma unroll
    // for(int it=0; it < Gmem_tile_k::LDGS; it++){
    //     gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    // }

    // Commit the data for K to shared memory.
    if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
    if (Kernel_traits::V_IN_REGS) {
        #pragma unroll
        for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
            smem_v.load(frag_v[ki], ki);
        }
    }

    // Commit the data for V to shared memory if it has not been done already.
    if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    // typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
    // smem_kt.load(frag_kt[0], 0);
    // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
    //     smem_kt.load(frag_kt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for this softmax at all
    Softmax softmax(params, smem_, tidx);

    // Declare the accumulators for the 3rd gemm.
    fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
    fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);

    // Load over the entire sequence length.
    for (int l = 0; l < steps; l++) {
        const int loop = (begin + l) * Cta_tile_p::M;
        if (loop >= binfo.actual_seqlen) break;

        // Load the fragments for V.
        // typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
        if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }

        // Load the fragments for dO.
        typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
        smem_do.load(frag_do[0], 0);

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);
        // Apply the mask.
        softmax.apply_mask(mask);
        // Scale by log-sum-exp of the softmax
        // softmax.apply_exp(p_lse);
        softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
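        // [Editor's note, not in the original commit] scale_apply_exp reconstructs the forward
        // softmax row from the saved log-sum-exp instead of re-reducing over the row:
        //     P_ij = exp(scale_bmm1f * S_ij - lse_i),  lse_i = log(sum_j exp(scale_bmm1f * S_ij)),
        // so the backward pass only needs the O(seqlen) lse vector (p_lse) saved by the forward pass.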
        if (Is_dropout) {
            // softmax.apply_dropout(ph, params.p_dropout_in_uint);
            // softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
            softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.pack(frag_p);

        // Store s * dmask to smem for transpose
        smem_s.store(frag_p);

        // Trigger the load for the next Q values.
        if (l < steps - 1) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load();
        }

        // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
        //     // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
        //     __syncthreads();
        // }

        // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
        // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
        // later. This is because loading dp_sum earlier uses more registers.
        fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        if (Is_first) {
            fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
        } else {
            #pragma unroll
            for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
                #pragma unroll
                for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
                    #pragma unroll
                    for (int ii = 0; ii < 8; ++ii) {
                        acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
                    }
                }
            }
        }

        // Do this part of dP^T = (dO * V^T)^T.
        #pragma unroll
        for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of dO values.
            smem_do.load(frag_do[ki & 1], ki);
            if (!Kernel_traits::V_IN_REGS) {
                smem_v.load(frag_v[ki & 1], ki);
                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            } else {
                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
            }
            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
            //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
            //     printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
            //     tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
            //     printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
            // }
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            if (!Kernel_traits::V_IN_REGS) {
                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            } else {
                fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
            }
        }

        // Load the fragments for K^T.
        typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
        smem_kt.load(frag_kt[0], 0);

        if (Is_first) {
            const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
            const int row[2] = {quad, quad + 8};
            smem_dp_sum.load(dp_sum, row, l % 2);
        }

        // Trigger the load for the next dO values.
        if (l < steps - 1) {
            smem_do.move_to_next_write_buffer();
            gmem_do.move();
            gmem_do.load();
            if (Is_first) {
                gmem_o.move();
                gmem_o.load();
            }
        }

        softmax.unpack_noscale(acc_dp);

        // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
        // // will be zero.

        // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
        if (Is_first) { softmax.subtract_dp_sum(dp_sum); }

        Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        softmax.pack(frag_dp);

        if (!Is_dropout) {
            #pragma unroll
            for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
                #pragma unroll
                for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
                    frag_p[mi][ni].hmul(frag_dp[mi][ni]);
                }
            }
        } else {
            __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
                dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
            }
            const __half zero_h = __half(0.f);
            #pragma unroll
            for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
                #pragma unroll
                for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
                    #pragma unroll
                    for (int ii = 0; ii < 4; ++ii) {
                        const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
                        const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
                        // If this element is dropped, then frag_p stores -p instead of p.
                        // So pd holds -p * dp_sum in that case.
                        const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
                        const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
                        const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
                        frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
                    }
                }
            }
        }

        // Store dp to smem for transpose
        smem_dp.store(frag_p);

        // gmem_s.store(frag_p, mask);
        // gmem_s.move();

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for (int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_kt.load(frag_kt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dq::MMAS_K;
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        static_assert(Gmem_tile_dq::LOOPS == 1);

        // Swizzle the elements and do the final reduction.
        smem_dq.store(acc_dq, 0);

        typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
        static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
        static_assert(Mma_tile_dkv::MMAS_K == 1);
        smem_dot.load(frag_dot[0], 0);

        // Threads in a warp is communicating via shared memory (smem_s and smem_dp)
        __syncwarp();
        typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
        smem_s.load(frag_s);

        if (Is_dropout) {
            #pragma unroll
            for (int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++) {
                #pragma unroll
                for (int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++) {
                    frag_s[ki][mi].hrelu_();
                }
            }
        }

        #pragma unroll
        for (int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_dot.load(frag_dot[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dkv::MMAS_K;
            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
        }

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if (l < steps - 1) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
        if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if (l < steps - 1) {
            gmem_do.commit(smem_do);
            if (Is_first) {
                // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
                // smem_dp_sum.move_to_next_write_buffer();
                dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
                const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
                if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
                    gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
                }
                gmem_softmax_d.move();
            }
            gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
            gmem_softmax_lse.move();
            if (!Is_first) {
                gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
                gmem_softmax_d.move();
            }
        }

        typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
        smem_dp.load(frag_dpt);

        gemm_q_k.reload_k();

        typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
        static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
        static_assert(Mma_tile_dkv::MMAS_K == 1);
        smem_qt.load(frag_qt[0], 0);

        #pragma unroll
        for (int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dkv::MMAS_K;
            fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Make sure dQ is in shared memory.
        __syncthreads();

        // Load from shared memory.
        smem_dq.template load</*zero_init=*/Is_first>(dq_out);

        const bool is_final_write =
            Is_last
            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
        if (is_final_write) {
            // if (Is_dropout) {
            //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
            // }
            dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
            // Output the values.
            gmem_dq.store(dq_out, 0);
            // Move to the next part of the output.
            gmem_dq.move();
        } else {
            // Output the values.
            gmem_dq_tmp.store(dq_out, 0);
        }

        // Move to the next part of the output.
        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }

        // // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
        if (l < steps - 1) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_qt.move_to_next_read_buffer();
            // smem_qt.load(frag_qt[0], 0);
            smem_do.move_to_next_read_buffer();
            smem_dot.move_to_next_read_buffer();
            // smem_dot.load(frag_dot[0], 0);
        }

    }  // Outer loop over the sequence length.

    if (Is_dropout) {
        for (int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++) {
            for (int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++) {
                acc_dv[mi][ni].mul_(params.rp_dropout);
            }
        }
    }
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
    // }
    for (int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++) {
        for (int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++) {
            // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
            acc_dk[mi][ni].mul_(params.scale_bmm1f);
        }
    }
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
    // }

    __syncthreads();
    // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
    // the total amount of shared mem?
    // Epilogue swizzle for dV
    Smem_tile_dv smem_dv(&smem_[0], tidx);
    smem_dv.store(acc_dv);

    // Epilogue swizzle for dK
    Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
    smem_dk.store(acc_dk);

    __syncthreads();
    uint4 dv_out[Smem_tile_dv::NUM_LDS];
    smem_dv.load(dv_out);
    Qkv_params dv_params;
    dv_params.qkv_ptr = params.dqkv_ptr;
    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
    dv_params.h = params.h;
    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
    if (!Is_first) { gmem_dv.move(loop_step_idx); }
    gmem_dv.store(dv_out);

    uint4 dk_out[Smem_tile_dk::NUM_LDS];
    smem_dk.load(dk_out);
    // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
    //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
    // }
    Qkv_params dk_params;
    dk_params.qkv_ptr = params.dqkv_ptr;
    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
    dk_params.h = params.h;
    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
    if (!Is_first) { gmem_dk.move(loop_step_idx); }
    gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));

    if (loop_steps == 1) {
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
    } else if (loop_steps == 2) {
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
    } else {
        if (params.s == N_per_loop) {
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
        } else {
            const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
            for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
                compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
            }
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
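
[Editor's note] compute_dq_dk_dv_1xN above walks the key/value sequence in blocks of Cta_tile_p::N and marks the first and last block, so each iteration knows when to compute dP_sum from scratch and when to emit the final scaled dQ (loop_steps == 1 and 2 exist only so those common cases can be specialized at compile time). A host-side sketch of the same splitting, with illustrative sizes only and not part of the commit:

#include <cstdio>

int main() {
    // Illustrative numbers only: a 640-token sequence processed in key/value blocks of 256.
    const int seqlen = 640, n_per_loop = 256;
    const int max_loop_steps = (seqlen + n_per_loop - 1) / n_per_loop;  // ceil division, as in the driver
    for (int loop_step_idx = 0; loop_step_idx < max_loop_steps; ++loop_step_idx) {
        const bool is_first = (loop_step_idx == 0);                    // dP_sum is computed here and dQ starts from zero
        const bool is_last = (loop_step_idx == max_loop_steps - 1);    // dQ gets its final scale and store on this block
        printf("block %d: Is_first=%d Is_last=%d\n", loop_step_idx, is_first, is_last);
    }
    return 0;
}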
csrc/stream_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h (new file, mode 100644)
/* Copyright (c) 2022, Tri Dao.
 */

#pragma once

#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int MMAS_M>
inline __device__ void dot_fragments(float (&sum)[MMAS_M * 2],
                                     const fmha::Fragment_a<fmha::Row> (&x)[MMAS_M],
                                     const fmha::Fragment_a<fmha::Row> (&y)[MMAS_M]) {
    #pragma unroll
    for (int mi = 0; mi < MMAS_M; ++mi) {
        sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(0), y[mi].template elt_as<__half2>(0));
        sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(2), y[mi].template elt_as<__half2>(2));
        sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(1), y[mi].template elt_as<__half2>(1));
        sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(3), y[mi].template elt_as<__half2>(3));
        // hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(0), y[mi].template elt_as<__half2>(0));
        // hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(2), y[mi].template elt_as<__half2>(2));
        // hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(1), y[mi].template elt_as<__half2>(1));
        // hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(3), y[mi].template elt_as<__half2>(3));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, typename Params>
inline __device__ void compute_dp_dq_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K^T. Treat K^T as V
    using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;

    // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
    // The shared memory tile to load dO.
    // Treating dO as Q.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load O. Loading O here is similar to loading dO.
    using Gmem_tile_o = Gmem_tile_do;
    // The shared memory tile to load O.
    using Smem_tile_o = Smem_tile_do;

    // The global memory tile to store dQ.
    // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
    using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
    // The shared memory tile to swizzle dQ.
    using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if (binfo.stop_early()) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for dQ.
    Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p> mask(binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
    Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for dO.
    Smem_tile_do smem_do(&smem_[0], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
    // Allocate the shared memory tile loader for O.
    Smem_tile_o smem_o(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for O.
    gmem_o.load();

    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    #pragma unroll
    for(int it=0; it < Gmem_tile_k::LDGS; it++){
        gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    }

    // Commit the data for Q, dO, and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_do.commit(smem_do);
    gmem_o.commit(smem_o);
    gmem_v.commit(smem_v);

    // Commit the data for K to shared memory.
    if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for dO.
    typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
    smem_do.load(frag_do[0], 0);

    // Load the fragments for O.
    typename Smem_tile_o::Fragment frag_o[2][Mma_tile_p::MMAS_M];
    smem_o.load(frag_o[0], 0);

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
    smem_kt.load(frag_kt[0], 0);
    // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
    //     smem_kt.load(frag_kt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for this softmax at all
    // Softmax softmax(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);
    Softmax softmax(params, smem_, tidx);
    // Softmax softmax_dp(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);

    Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;

    // Load over the entire sequence length.
    for (int l = 0; l < STEPS; l++) {
        const int loop = l * Cta_tile_p::M;
        if (loop >= binfo.actual_seqlen) break;

        float p_lse[Mma_tile_p::MMAS_M * 2];
        gmem_softmax_sum.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
        gmem_softmax_sum.move();

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);

        // Trigger the load for the next Q values.
        if (l < STEPS - 1) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load();
        }

        // Load the mask for that iteration.
        mask.load(l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);
        // Apply the mask.
        softmax.apply_mask(mask);
        // Scale by log-sum-exp of the softmax
        softmax.template apply_exp</*max_in_base2=*/true>(p_lse);
        // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.pack(frag_p);

        // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
        //     // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
        //     __syncthreads();
        // }

        float dp_sum_new[Mma_tile_p::MMAS_M * 2] = {0};
        dot_fragments(dp_sum_new, frag_do[0], frag_o[0]);

        fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);

        // Do this part of dP^T = (dO * V^T)^T.
        #pragma unroll
        for (int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of dO values.
            smem_do.load(frag_do[ki & 1], ki);
            smem_o.load(frag_o[ki & 1], ki);
            dot_fragments(dp_sum_new, frag_do[ki & 1], frag_o[ki & 1]);
            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     printf("dp_sum_new=%.6f, %.6f\n", dp_sum_new[0], dp_sum_new[1]);
            // }
            // smem_v.load(frag_v[ki & 1], ki);
            // Do the math for the values already in registers.
            // fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            // if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
            //     printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
            //     tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki - 1]));
            //     printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
            // }
            fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            // fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
        }

        // if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //     printf("acc_dp=%.6f, %.6f\n", acc_dp[0][0].elt(0), acc_dp[0][0].elt(1));
        // }

        // Trigger the load for the next dO values.
        if (l < STEPS - 1) {
            smem_do.move_to_next_write_buffer();
            gmem_do.move();
            gmem_do.load();
            smem_o.move_to_next_write_buffer();
            gmem_o.move();
            gmem_o.load();
        }

        // softmax_dp.unpack_noscale(acc_dp);
        // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
        // // will be zero.
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         softmax_dp.elt_[mi][ni] *= softmax.elt_[mi][ni];
        //     }
        // }
        // float dp_sum[Mma_tile_p::MMAS_M * 2];
        // softmax_dp.reduce_sum(dp_sum);
        // gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
        // gmem_softmax_d.move();
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         softmax_dp.elt_[mi][ni] -= dp_sum[mi] * softmax.elt_[mi][ni];
        //     }
        // }
        fmha::SumOp<float> sum_op;
        fmha::quad_allreduce(dp_sum_new, dp_sum_new, sum_op);
        // softmax_dp.unpack_noscale(acc_dp);
        softmax.unpack_noscale(acc_dp);
        // #pragma unroll
        // for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
        //     #pragma unroll
        //     for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
        //         // softmax_dp.elt_[mi][ni] -= dp_sum_new[mi];
        //         softmax.elt_[mi][ni] -= dp_sum_new[mi];
        //     }
        // }
        softmax.subtract_dp_sum(dp_sum_new);

        Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        // softmax_dp.pack(frag_dp);
        softmax.pack(frag_dp);

        gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum_new));
        gmem_softmax_d.move();

        #pragma unroll
        for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
            #pragma unroll
            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
                frag_p[mi][ni].hmul(frag_dp[mi][ni]);
            }
        }
        // softmax_dp.pack(frag_p);
        // gmem_s.store(frag_p, mask);
        // gmem_s.move();

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if (l < STEPS - 1) {
            gmem_q.commit(gemm_q_k.smem_q);
            gmem_do.commit(smem_do);
            gmem_o.commit(smem_o);
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for (int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_kt.load(frag_kt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dq::MMAS_K;
            fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for (int ii = 0; ii < Gmem_tile_dq::LOOPS; ++ii) {
            // Swizzle the elements and do the final reduction.
            smem_dq.store(acc_dq, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_dq::STGS_PER_LOOP];
            smem_dq.load(out);
            // Make sure the data was read from shared memory.
            if (ii < Gmem_tile_dq::LOOPS - 1) {
                __syncthreads();
            }
            // Output the values.
            gmem_dq.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_dq.move();
        gemm_q_k.reload_k();
        smem_kt.load(frag_kt[0], 0);

        // // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
        if (l < STEPS - 1) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_do.move_to_next_read_buffer();
            smem_do.load(frag_do[0], 0);
            smem_o.move_to_next_read_buffer();
            smem_o.load(frag_o[0], 0);
        }

    }  // Outer loop over the sequence length.
}
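
// [Editor's note, not in the original commit] The loop above follows the standard attention
// backward identities that the rest of this file relies on:
//     dP = dO * V^T,   D_i = sum_d dO_id * O_id   (dp_sum_new),
//     dS = P * (dP - D)  elementwise,   dQ = dS * K,
// with K pre-scaled by scale_bmm1 at load time and P recomputed from the saved log-sum-exp.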

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_dk_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dkv = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dkv>;

    // The global memory tile to load Q. Treating Q as K.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_k;
    // The global memory tile to load K. Treating K as Q.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q^T. Treat Q^T as V
    using Smem_tile_qt = typename Kernel_traits::Smem_tile_v;

    // Treating V as dO.
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_q;

    // Treating dO as V in dQ kernel, which is the same as K in the forward kernel.
    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_dot;
    // The shared memory tile to load dO.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_k;
    // The shared memory tile to load dO^T.
    using Smem_tile_dot = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store dK and dV.
    // using Gmem_tile_dkv = typename Kernel_traits::Gmem_tile_dkv;
    using Gmem_tile_dkv = fmha::Gmem_tile_dq<Cta_tile_dkv>;
    // The shared memory tile to swizzle dK and dV.
    using Smem_tile_dkv = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if (binfo.stop_early()) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for dK.
    Gmem_tile_dkv gmem_dk(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for dV.
    Gmem_tile_dkv gmem_dv(params, 2, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p> mask(binfo, tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);

    // Allocate the shared memory tile loader for dO.
    Smem_tile_v smem_v(&smem_[0], tidx);

    // Allocate the shared memory tile loader for Q^T. We use the same as Q so be careful!!!
    Smem_tile_qt smem_qt(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
    // The base pointer of smem_do;
    char *smem_do_ = &smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_do smem_do(smem_do_, tidx);
    Smem_tile_dot smem_dot(smem_do_, tidx);

    // Allocate the shared memory tile loader for dK and dV. We use the same as K so be careful!!!
    Smem_tile_dkv smem_dkv(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for V.
    gmem_v.load();

    const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    #pragma unroll
    for(int it=0; it < Gmem_tile_q::LDGS; it++){
        gmem_q.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_q.fetch_[it]);
    }

    // Commit the data for K, dO, and V to shared memory.
    gmem_k.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);
    gmem_do.commit(smem_do);

    // Commit the data for Q to shared memory.
    if (!Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        gmem_q.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for K.
    gemm_q_k.load_q();

    // Load the fragments for V.
    typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_M];
    smem_v.load(frag_v[0], 0);

    // Load the fragments for dO. We keep the data in registers during the entire kernel.
    typename Smem_tile_do::Fragment frag_do[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for (int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki) {
        smem_do.load(frag_do[ki], ki);
    }

    using Smem_tile_mma_t = fmha::Smem_tile_transpose<Cta_tile_p>;
    // Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
    Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dkv::BYTES_PER_TILE], tidx);

    // Commit the data for V to shared memory if it has not been done already.
    if (Kernel_traits::SHARE_SMEM_FOR_K_AND_V) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_q.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for Q.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
    smem_qt.load(frag_qt[0], 0);
    // typename Smem_tile_qt::Fragment frag_qt[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
    //     smem_qt.load(frag_qt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for either of the softmax at all
    Softmax softmax(params, smem_, tidx);
    Softmax softmax_dp(params, smem_, tidx);

    Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
    int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
    int rows[Mma_tile_p::MMAS_N * 4];
    for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
        rows[ni * 4 + 0] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2;
        rows[ni * 4 + 1] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 1;
        rows[ni * 4 + 2] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 8;
        rows[ni * 4 + 3] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 9;
    }
    float p_lse[Mma_tile_p::MMAS_N * 4];
    gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(p_lse), rows);
    float dp_sum[Mma_tile_p::MMAS_N * 4];
    gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(dp_sum), rows);
    // int qid = lane / 8;
    // int rows_shfl[Mma_tile_p::MMAS_N];
    // for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
    //     rows_shfl[ni] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + (qid / 2) * 8 + (qid % 2);
    // }
    // float p_lse[Mma_tile_p::MMAS_N];
    // gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(p_lse), rows_shfl);
    // float dp_sum[Mma_tile_p::MMAS_N];
    // gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(dp_sum), rows_shfl);

    constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;

    // Load over the entire sequence length.
    for (int l = 0; l < STEPS; l++) {
        const int loop = l * Cta_tile_p::M;
        if (loop >= binfo.actual_seqlen) break;

        typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_p::MMAS_N];
        // smem_mmat.store(frag_do, 0);
        // smem_mmat.load(frag_dot[0]);
        // smem_mmat.transpose(frag_do, frag_dot[0], 0);

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_pt[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_pt);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_pt);

        // Trigger the load for the next K values.
        if (l < STEPS - 1) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_k.move();
            gmem_k
.
load
();
}
// Load the mask for that iteration.
mask
.
load
(
l
);
// Convert from the accumulator type to FP32 for Softmax.
softmax
.
unpack_noscale
(
acc_pt
);
// Apply the mask.
softmax
.
apply_mask
(
mask
);
// Scale by log-sum-exp of the softmax
softmax
.
template
apply_exp_col
<
/*max_in_base2=*/
true
>(
p_lse
);
using
Frag_p
=
fmha
::
Fragment_a
<
fmha
::
Row
>
;
Frag_p
frag_p
[
Mma_tile_dk
::
MMAS_K
][
Mma_tile_dk
::
MMAS_M
];
softmax
.
pack
(
frag_p
);
// Declare the accumulators for the 1st gemm.
fmha
::
Fragment_accumulator
acc_dv
[
Mma_tile_dk
::
MMAS_M
][
Mma_tile_dk
::
MMAS_N
];
fmha
::
Clear_accumulator
<
typename
fmha
::
Accumulator_type
,
Cta_tile_dkv
::
WARPS_K
>::
apply
(
acc_dv
);
smem_mmat
.
transpose
(
frag_do
,
frag_dot
[
0
],
0
);
// Do this part of O = P^T * V^T.
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_dk
::
MMAS_K
;
++
ki
)
{
// fmha::gemm(acc_dv, frag_p[ki], frag_dot[ki]);
if
(
ki
+
1
<
Mma_tile_dk
::
MMAS_K
)
{
// smem_mmat.store(frag_do, ki + 1);
// smem_mmat.load(frag_dot[(ki + 1) % 2]);
smem_mmat
.
transpose
(
frag_do
,
frag_dot
[(
ki
+
1
)
%
2
],
ki
+
1
);
}
fmha
::
gemm
(
acc_dv
,
frag_p
[
ki
],
frag_dot
[
ki
%
2
]);
}
__syncthreads
();
// Loop over MMAS_M.
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Gmem_tile_dkv
::
LOOPS
;
++
ii
)
{
// Swizzle the elements and do the final reduction.
smem_dkv
.
store
(
acc_dv
,
ii
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load from shared memory.
uint4
out
[
Gmem_tile_dkv
::
STGS_PER_LOOP
];
smem_dkv
.
load
(
out
);
// Make sure the data was read from shared memory.
if
(
ii
<
Gmem_tile_dkv
::
LOOPS
-
1
)
{
__syncthreads
();
}
// Output the values.
gmem_dv
.
store
(
out
,
ii
);
}
// Move to the next part of the output.
gmem_dv
.
move
();
if
(
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
&&
l
==
0
)
{
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads
();
}
fmha
::
Fragment_accumulator
acc_dpt
[
Mma_tile_p
::
MMAS_M
][
Mma_tile_p
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_p
::
WARPS_K
>::
apply
(
acc_dpt
);
// Do this part of dP^T = (dO * V^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of dO values.
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dpt
,
frag_v
[(
ki
-
1
)
&
1
],
frag_do
[
ki
-
1
]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
// fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[(ki - 1) & 1]);
fmha
::
gemm
(
acc_dpt
,
frag_v
[(
ki
-
1
)
&
1
],
frag_do
[(
ki
-
1
)]);
}
// Trigger the load for the next V values.
if
(
l
<
STEPS
-
1
)
{
smem_v
.
move_to_next_write_buffer
();
gmem_v
.
move
();
gmem_v
.
load
();
}
softmax_dp
.
unpack_noscale
(
acc_dpt
);
// TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// will be zero.
#pragma unroll
for
(
int
mi
=
0
;
mi
<
Mma_tile_p
::
MMAS_M
*
2
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile_p
::
MMAS_N
*
4
;
ni
++
)
{
// softmax.elt_[mi][ni] *= (softmax_dp.elt_[mi][ni] - dp_sum[ni]);
softmax_dp
.
elt_
[
mi
][
ni
]
-=
dp_sum
[
ni
];
// const float tmp = __shfl_sync(0xffffffff, dp_sum[ni / 4], (ni % 4) * 8 + threadIdx.x % 8);
// softmax_dp.elt_[mi][ni] -= tmp;
}
}
Frag_p
frag_dp
[
Mma_tile_dk
::
MMAS_K
][
Mma_tile_dk
::
MMAS_M
];
softmax_dp
.
pack
(
frag_dp
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
Mma_tile_p
::
MMAS_M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
Mma_tile_p
::
MMAS_N
;
ni
++
)
{
frag_p
[
mi
][
ni
].
hmul
(
frag_dp
[
mi
][
ni
]);
}
}
// using Frag_p = fmha::Fragment_a<fmha::Row>;
// Frag_p frag_p[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
// softmax.pack(frag_p);
// softmax_dp.pack(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
__syncthreads
();
// Commit the values for K and V into shared memory.
if
(
l
<
STEPS
-
1
)
{
gmem_k
.
commit
(
gemm_q_k
.
smem_q
);
gmem_v
.
commit
(
smem_v
);
}
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dk
[
Mma_tile_dk
::
MMAS_M
][
Mma_tile_dk
::
MMAS_N
];
fmha
::
Clear_accumulator
<
typename
fmha
::
Accumulator_type
,
Cta_tile_dkv
::
WARPS_K
>::
apply
(
acc_dk
);
// Do this part of O = P^T * V^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_dk
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_dk
,
frag_p
[
ki
-
1
],
frag_qt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_dk
,
frag_p
[
ki
-
1
],
frag_qt
[(
ki
-
1
)
&
1
]);
// fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
}
// Loop over MMAS_M.
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Gmem_tile_dkv
::
LOOPS
;
++
ii
)
{
// Swizzle the elements and do the final reduction.
smem_dkv
.
store
(
acc_dk
,
ii
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load from shared memory.
uint4
out
[
Gmem_tile_dkv
::
STGS_PER_LOOP
];
smem_dkv
.
load
(
out
);
// Make sure the data was read from shared memory.
if
(
ii
<
Gmem_tile_dkv
::
LOOPS
-
1
)
{
__syncthreads
();
}
// Output the values.
gmem_dk
.
store
(
out
,
ii
);
}
// Move to the next part of the output.
gmem_dk
.
move
();
gemm_q_k
.
reload_k
();
smem_qt
.
load
(
frag_qt
[
0
],
0
);
// Make sure the data is in shared memory.
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if
(
l
<
STEPS
-
1
)
{
gemm_q_k
.
smem_q
.
move_to_next_read_buffer
();
gemm_q_k
.
reload_q
();
smem_v
.
move_to_next_read_buffer
();
smem_v
.
load
(
frag_v
[
0
],
0
);
}
}
// Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace fmha
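Not part of the commit, but as a reading aid: the loop above recomputes P from the stored log-sum-exp, accumulates dV = P^T dO, forms dP = dO V^T, turns it into dS = P * (dP - dp_sum) (with dp_sum = rowsum(dO * O) read from params.dsoftmax_sum), and accumulates dK = dS^T Q. A rough NumPy sketch of that per-tile math, ignoring tiling, dropout and the base-2 bookkeeping; all names here are illustrative:

import numpy as np

def attention_bwd_sketch(q, k, v, do, softmax_lse, dsoftmax_sum, scale):
    # Recompute the probabilities from the saved row-wise log-sum-exp,
    # as the kernel does with apply_exp_col(p_lse).
    s = scale * (q @ k.T)                         # (M, N) attention scores
    p = np.exp(s - softmax_lse[:, None])          # softmax probabilities
    dv = p.T @ do                                 # dV = P^T dO
    dp = do @ v.T                                 # dP = dO V^T
    ds = p * (dp - dsoftmax_sum[:, None])         # dS = P * (dP - rowsum(dO * O))
    dk = scale * (ds.T @ q)                       # dK = dS^T Q (this kernel's acc_dk)
    dq = scale * (ds @ k)                         # dQ (handled by a companion kernel)
    return dq, dk, dv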
csrc/stream_attn/src/fmha_fprop_fp16_kernel.sm80.cu
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}

template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                              const bool configure) {
    bool is_causal = launch_params.params.is_causal;
    // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
    auto kernel = launch_params.is_dropout
        ? (is_causal
               ? (launch_params.return_softmax
                      ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true>
                      : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
               : (launch_params.return_softmax
                      ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true>
                      : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
        : (is_causal
               ? (launch_params.return_softmax
                      ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true>
                      : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
               : (launch_params.return_softmax
                      ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true>
                      : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));

    constexpr int N = Kernel_traits::Cta_tile_p::N;
    const int loop_steps = (launch_params.params.s + N - 1) / N;
    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    // Don't need smem_size_softmax_lse if we're not looping
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
        + (loop_steps > 1 ? smem_size_softmax_lse : 0);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    if( configure ) {
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int M = Kernel_traits::Cta_tile_p::M;
        size_t STEPS = (launch_params.params.s + M - 1) / M;
        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
        launch_params.elts_per_thread = elts_per_head;
        return;
    }

    dim3 grid(launch_params.params.h, launch_params.params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(launch_params.params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                        const bool configure) {
    if (launch_params.params.d == 16) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s == 256 ) {
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            // TD [2022-05-15] 512 gives wrong results rn
            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>;
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 32) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s == 256 ) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 64) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s == 256 ) {
            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 128) {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    }
    // if (launch_params.params.d == 64) {
    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
    //     using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
    //     using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
    //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    // }
}
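An aside (not in the commit): the launch configuration above is just ceiling divisions over the CTA tile sizes. A tiny Python sketch of the configure branch, where M, N, MMAS_M and MMAS_N are illustrative stand-ins for values that really come from Kernel_traits:

def configure_launch(s, h, b, M=16, N=256, MMAS_M=1, MMAS_N=16):
    # s: max sequence length, h: number of heads, b: batch size (illustrative values).
    loop_steps = (s + N - 1) // N           # outer iterations over key/value blocks
    steps = (s + M - 1) // M                # query tiles visited per outer iteration
    elts_per_head = steps * MMAS_M * MMAS_N * 8 * loop_steps   # RNG elements per thread
    grid = (h, b)                           # one CTA per (head, batch) pair, as in dim3 grid(h, b)
    return loop_steps, elts_per_head, grid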
csrc/stream_attn/src/fmha_fprop_kernel_1xN.h
/***************************************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
#include <fmha/utils.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits>
struct Gemm_Q_K_base {
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
    using Fragment_q = typename Smem_tile_q::Fragment;
    using Fragment_k = typename Smem_tile_k::Fragment;

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;

    __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
        : smem_q(smem_ptr_q, tidx)
        , smem_k(smem_ptr_k, tidx) {
    }

    __device__ inline void load_q() {
        smem_q.load(frag_q[0], 0);
    }

    __device__ inline void reload_q() {
        smem_q.load(frag_q[0], 0);
    }

    Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
    Smem_tile_q smem_q;
    Smem_tile_k smem_k;
};

template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {

    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;

    static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
    // If V is stored in shared memory, we can't load K using the same shared memory.
    static_assert(Kernel_traits::V_IN_REGS);

    static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);

    // Q | K / V
    //   | O | SOFTMAX
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
                                               Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);

    __device__ inline Gemm_Q_K(char * smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    __device__ inline void load_k(){
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            Base::smem_k.load(frag_k[ki], ki);
        }
    }

    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]){
        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            Base::smem_q.load(Base::frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }
    }

    __device__ inline void reload_k(){
        // Noop.
    }

    Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};

template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
    using Base = Gemm_Q_K_base<Kernel_traits>;
    using Smem_tile_o = typename Base::Smem_tile_o;
    using Smem_tile_q = typename Base::Smem_tile_q;
    using Smem_tile_k = typename Base::Smem_tile_k;
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
    using Fragment_k = typename Base::Fragment_k;
    using Mma_tile_p = typename Base::Mma_tile_p;

    Fragment_k frag_k[2][Mma_tile_p::MMAS_N];

    static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
    static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
    static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);

    static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
    static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
    static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE;
    static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;

    // If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V:      Q | K/V | O | SOFTMAX
    // If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K   | V | O | SOFTMAX
    static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
                                    + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
                                    + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;

    __device__ inline Gemm_Q_K(char * smem_, const int tidx)
        : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
    }

    __device__ inline void load_k(){
        Base::smem_k.load(frag_k[0], 0);
    }

    template<typename Acc, int M, int N>
    __device__ inline void operator()(Acc (&acc_p)[M][N]){
        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            Base::smem_q.load(Base::frag_q[ki & 1], ki);
            Base::smem_k.load(frag_k[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
        }
        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
        }
    }

    __device__ inline void reload_k(){
        Base::smem_k.load(frag_k[0], 0);
    }
};

template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
    return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;

    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    // if( binfo.stop_early() ) return;
    if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;

    Gemm1 gemm_q_k(smem_, tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);
    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

    // Wind gmem tiles to the correct position.
    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
    const int begin_og = begin;
    begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
    const int steps_og = steps;
    steps -= begin - begin_og;
    gmem_q.move(begin);
    gmem_o.move(begin);
    gmem_o_tmp.move(begin);
    if (Return_softmax) { gmem_s.move(begin); }
    gmem_softmax_lse.move(begin);
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("begin = %d, steps = %d\n", begin, steps);
    // }

    fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

    if (!Is_first) {
        gmem_k.move(loop_step_idx);
        gmem_v.move(loop_step_idx);
        if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
    }

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();

    if (!Is_first) { __syncthreads(); }

    float p_prev_lse[Mma_tile_p::MMAS_M * 2];
    if (!Is_first) {
        gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
    }

    // Commit the data for Q and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);

    // const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    // #pragma unroll
    // for(int it=0;it < Gmem_tile_k::LDGS;it++){
    //     gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    // }

    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();

    // Create the object to do the softmax.
    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);

    Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);

    // Load over the entire sequence length.
    for( int l = 0; l < steps; l++ ) {
        if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen) break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P = Q * K^T.
        gemm_q_k(acc_p);

        uint4 out[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) { gmem_o_tmp.load(out, 0); }

        // Trigger the load for the next Q values.
        if( l < steps - 1) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load();
        }

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);
        // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
        //     }
        // }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        if (!Is_first) {
            smem_softmax_lse.store_pair(p_prev_lse, l % 2);
            // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
        }

        // Trigger the load for the next LSE values.
        if( l < steps - 1) {
            if (!Is_first) {
                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
            }
        }

        // __half2 p_max[Mma_tile_p::MMAS_M];
        softmax.template reduce_max</*zero_init=*/Is_first>(p_max);

        // if ((threadIdx.x == 0) && (l == 38)) {
        //     printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
        // }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the exponential value.
        // softmax.apply_exp(p_max);
        softmax.scale_apply_exp(p_max, params.scale_bmm1f);

        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        // if (!Is_first) {
        //     int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
        //     int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
        //     for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
        //         p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
        //     }
        // }
        // softmax.reduce_sum(p_sum);
        softmax.reduce_sum_before_sync_(p_sum);
        // softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);

        // float p_sum_log[Mma_tile_p::MMAS_M * 2];
        // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
        //     float sum = p_sum[mi];
        //     // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
        //     constexpr float kLog2e = M_LOG2E;
        //     p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
        // }
        // // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
        // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
        // gmem_softmax_lse.move();

        // // Finalize softmax on the accumulators of P^T.
        // softmax.scale(p_sum);

        constexpr bool encode_dropout_in_sign_bit = Return_softmax;
        if (Is_dropout) {
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
            softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.pack(frag_p);
        if (Return_softmax) {
            gmem_s.store(frag_p, mask);
            gmem_s.move();
        }

        // Commit the values for Q into shared memory.
        if(l < steps - 1) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        if (Is_dropout && encode_dropout_in_sign_bit) {
            #pragma unroll
            for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
                #pragma unroll
                for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                    frag_p[ki][mi].hrelu_();
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // The mapping from tidx to rows changes between the softmax and the O-reduction.
        // So we recalculate the max.
        float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        // TODO: not sure if this is right for seqlen 128 or 256
        int rows[Gmem_tile_o::STGS_PER_LOOP];
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
        }
        softmax.reduce_max_after_sync_(p_max_o, rows);
        static_assert(Mma_tile_o::MMAS_M == 1);
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            p_max_o[jj][0] *= params.scale_bmm1f;
        }

        float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
        //         printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
        //     }
        // }

        static_assert(Gmem_tile_o::LOOPS == 1);

        // Swizzle the elements and do the final reduction.
        smem_o.store(acc_o, 0);

        // Make sure the data is in shared memory.
        __syncthreads();

        static_assert(Mma_tile_o::MMAS_M == 1);
        float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        softmax.reduce_sum_after_sync_(p_sum_o, rows);
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
                p_sum_o[jj][0] += p_prev_scale_o[jj];
            }
        }

        float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
            // if (sum == 0.f || sum != sum) {
            //     printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
            // }
            // if (Is_first) {
            //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
            //         printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
            //     }
            // }
            if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) {
                gmem_softmax_lse.store_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
            }
        }
        gmem_softmax_lse.move();

        // Load from shared memory.
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
            }
        }
        smem_o.template load</*zero_init=*/Is_first>(out);

        const bool is_final_write =
            Is_last
            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            if (Is_dropout && is_final_write) {
                inv_sum *= params.rp_dropout;
            }
            out[jj] = fmha::fmul4(out[jj], inv_sum);
        }

        // if (Is_dropout && Is_last) {
        //     for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
        //         out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
        //     }
        // }

        // Output the values.
        if (is_final_write) {
            gmem_o.store(out, 0);
            gmem_o.move();
        } else {
            gmem_o_tmp.store(out, 0);
        }

        // Move to the next part of the output.
        if (!(Is_first && Is_last)) { gmem_o_tmp.move(); }

        gemm_q_k.reload_k();

        // Make sure we are reading from the correct buffer.
        gemm_q_k.smem_q.move_to_next_read_buffer();
        // Trigger the load from shared memory for the next series of Q values.
        if(l < steps - 1) {
            gemm_q_k.reload_q();
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_1xN_loop(const Params &params) {

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
    Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
    const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;

    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
    if (params.s == N_per_loop) {
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
    } else {
        const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
        for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
        }
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
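Not part of the commit: the chunked forward pass above never materializes the full attention matrix; it carries a running (unnormalized) output, a running max and a running sum across key/value chunks, rescales the previous partial result before folding in each new chunk, and finally stores the log-sum-exp for the backward pass. A rough NumPy model of that accumulation (single head, no dropout or masking; names are illustrative):

import numpy as np

def chunked_softmax_attention(q, k, v, chunk=256, scale=1.0):
    M = q.shape[0]
    acc = np.zeros((M, v.shape[1]))               # unnormalized output accumulator
    row_max = np.full(M, -np.inf)                 # running row-wise max
    row_sum = np.zeros(M)                         # running row-wise sum of exp
    for start in range(0, k.shape[0], chunk):
        kc, vc = k[start:start + chunk], v[start:start + chunk]
        s = scale * (q @ kc.T)
        new_max = np.maximum(row_max, s.max(axis=1))
        p = np.exp(s - new_max[:, None])
        prev_scale = np.exp(row_max - new_max)    # plays the role of p_prev_scale_o
        acc = acc * prev_scale[:, None] + p @ vc
        row_sum = row_sum * prev_scale + p.sum(axis=1)
        row_max = new_max
    lse = row_max + np.log(row_sum)               # what gmem_softmax_lse ends up storing
    return acc / row_sum[:, None], lse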
csrc/stream_attn/src/fmha_kernel.h
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <philox.cuh>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params,
                               const int bidb,
                               const int bidh,
                               const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {

        // The block index.
        sum_s = params.cu_seqlens[bidb];
        actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
        bidx = sum_s * params.h + bidh;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early(const int start_col = 0) const {
        return actual_seqlen <= start_col;
    }

    int actual_seqlen;
    int bidx;
    int sum_s;
    int bidh;
    int bidb;
    int tidx_global;
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int CHUNKS, typename Cta_tile>
struct Noloop_traits {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };

    template<typename Block_info>
    inline __device__ Noloop_traits(const int bidc, const Block_info &binfo)
        : bidc_(bidc) {
        const int seqlen = binfo.actual_seqlen;
        const int steps = (seqlen + STEP - 1) / STEP;
        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;

        const int step_begin = bidc_ * steps_per_chunk;
        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
        const int actual_steps = max(0, step_end - step_begin);
        loop_offset_ = step_begin;
        num_steps_ = actual_steps;
    }

    template<typename ... Tiles>
    inline __device__ void move_all(Tiles & ... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // convert loop counter to position in the outer sequence
        return (loop_offset_ + l) * STEP;
    }

    const uint32_t bidc_;
    int loop_offset_;
    int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits>
std::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {

    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;

    const int num_full_heads = heads_total / total_ctas;
    const int heads_last_wave = heads_total % total_ctas;

    int num_main_groups = 0;
    int main_steps = 0;
    int rest_steps = 0;
    if( heads_last_wave > 0 ) {
        // Number of CTA groups that process within heads.
        num_main_groups = total_ctas / heads_last_wave;
        // Remaining CTAs that process between heads.
        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
        if( rest_ctas == 0 ) {
            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
            num_main_groups = STEPS_PER_HEAD / main_steps;
            // Here: main_step > 0
            rest_steps = STEPS_PER_HEAD % main_steps;
        } else {
            // Ideal number of steps if we could load-balance as evenly as possible.
            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
            // Iterations that a "rest" CTA has to do at most.
            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
            main_steps = steps_ideal;
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
                const int max_rest_total_steps = rest_steps * max_rest_iters;
                if( max_rest_total_steps < main_steps )
                    break;
            }
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
        }
    }

    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
    const int elts_per_thread = max_steps * elts_per_thread_per_step;

    return { num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
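Not part of the commit: BlockInfoPadded indexes into a packed variable-length batch through cu_seqlens, the prefix sums of the per-sequence lengths. A small PyTorch sketch of that layout:

import torch

def make_cu_seqlens(seqlens):
    # cu_seqlens[b] is the offset of sequence b in the packed token dimension, so
    # actual_seqlen = cu_seqlens[b + 1] - cu_seqlens[b], exactly as BlockInfoPadded computes.
    return torch.cumsum(torch.tensor([0] + list(seqlens)), dim=0, dtype=torch.int32)

cu = make_cu_seqlens([128, 64, 200])
# tensor([  0, 128, 192, 392], dtype=torch.int32)
# Sequence 1 occupies rows 128:192 of the packed batch, i.e. actual_seqlen = 64.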
csrc/stream_attn/src/fmha_utils.h
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline void set_alpha(uint32_t &alpha, float norm, Data_type dtype) {
    if( dtype == DATA_TYPE_FP16 ) {
        half x = __float2half_rn(norm);
        uint16_t h = reinterpret_cast<const uint16_t &>(x);
        ushort2 h2 = { h, h };
        alpha = reinterpret_cast<const uint32_t &>(h2);
    } else if( dtype == DATA_TYPE_FP32 ) {
        alpha = reinterpret_cast<const uint32_t &>(norm);
    } else if( dtype == DATA_TYPE_INT32 ) {
        int32_t inorm = static_cast<int32_t>(norm);
        alpha = reinterpret_cast<const uint32_t &>(inorm);
    } else {
        assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bytes(size_t n, Data_type dtype) {
    switch( dtype ) {
    case DATA_TYPE_FP32:
        return n * 4;
    case DATA_TYPE_FP16:
        return n * 2;
    case DATA_TYPE_INT32:
        return n * 4;
    case DATA_TYPE_INT8:
        return n;
    default:
        assert(false);
        return 0;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/stream_attn/src/philox.cuh
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
#pragma once
// Philox CUDA.
namespace {

class Philox {
public:
    __device__ inline Philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset)
        : STATE(0)
        , key(reinterpret_cast<const uint2&>(seed)) {
        //key.x = (unsigned int)seed;
        //key.y = (unsigned int)(seed >> 32);
        //counter = make_uint4(0, 0, 0, 0);
        //counter.z = (unsigned int)(subsequence);
        //counter.w = (unsigned int)(subsequence >> 32);
        //STATE = 0;
        //incr_n(offset / 4);

        // key = reinterpret_cast<const uint2&>(seed);
        ull2 * tmp = reinterpret_cast<ull2*>(&counter);
        tmp->x = offset / 4;
        tmp->y = subsequence;
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
        // }
    }

    __device__ inline uint4 operator()() {
        // if (STATE == 0) {
        uint4 counter_ = counter;
        uint2 key_ = key;
        // 7-round philox
        #pragma unroll
        for (int i = 0; i < 6; i++) {
            counter_ = single_round(counter_, key_);
            key_.x += (kPhilox10A);
            key_.y += (kPhilox10B);
        }
        // output = single_round(counter_, key_);
        uint4 output = single_round(counter_, key_);
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
        // }
        incr();
        // }
        // return a float4 directly
        // unsigned long ret;
        // switch(STATE) {
        //  case 0: ret = output.x; break;
        //  case 1: ret = output.y; break;
        //  case 2: ret = output.z; break;
        //  case 3: ret = output.w; break;
        //}
        // STATE = (STATE + 1) % 4;
        return output;
    }

private:
    struct ull2 {
        uint64_t x;
        uint64_t y;
    };
    uint4 counter;
    // uint4 output;
    const uint2 key;
    unsigned int STATE;

    __device__ inline void incr_n(unsigned long long n) {
        unsigned int nlo = (unsigned int)(n);
        unsigned int nhi = (unsigned int)(n >> 32);
        counter.x += nlo;
        if (counter.x < nlo)
            nhi++;
        counter.y += nhi;
        if (nhi <= counter.y)
            return;
        if (++counter.z)
            return;
        ++counter.w;
    }

    __device__ uint4 incr128(uint4 ctr) {
        uint4 res;
        asm ("add.cc.u32 %0, %4, %8;\n\t"
             "addc.cc.u32 %1, %5, %9;\n\t"
             "addc.cc.u32 %2, %6, %10;\n\t"
             "addc.u32 %3, %7, %11;\n\t"
             : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
             : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
               "n"(1), "n"(0), "n"(0), "n"(0));
        return res;
    }

    __device__ inline void incr() {
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        // }
        counter = incr128(counter);
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        // }
    }

    __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, unsigned int *result_high) {
        *result_high = __umulhi(a, b);
        return a * b;
    }

    __device__ uint2 mulhilo32_v2(const unsigned int a, const unsigned int b) {
        uint2 *res;
        unsigned long long tmp;
        asm ("mul.wide.u32 %0, %1, %2;\n\t"
             : "=l"(tmp)
             : "r"(a), "r"(b));
        res = (uint2*)(&tmp);
        return *res;
    }

    __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
        //unsigned int hi0;
        //unsigned int hi1;
        //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
        //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
        //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
        uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
        uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
        uint4 ret = { res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x };
        return ret;
    }

    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA = 0xD2511F53;
    static const unsigned long kPhiloxSB = 0xCD9E8D57;
};

// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;

__device__ __inline__ float4 uniform4(const uint4 x) {
    return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, x.w * M_RAN_INVM32);
}

}  // namespace
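Not part of the commit: for sanity-checking the generator above on the host, here is a rough pure-Python model of the same Philox-4x32 rounds and the uniform mapping, using the constants defined in this file:

MASK32 = 0xFFFFFFFF
PHILOX_10A, PHILOX_10B = 0x9E3779B9, 0xBB67AE85
PHILOX_SA, PHILOX_SB = 0xD2511F53, 0xCD9E8D57

def _mulhilo32(a, b):
    prod = a * b
    return prod & MASK32, prod >> 32               # (lo, hi), like mulhilo32_v2

def _single_round(ctr, key):
    lo0, hi0 = _mulhilo32(PHILOX_SA, ctr[0])
    lo1, hi1 = _mulhilo32(PHILOX_SB, ctr[2])
    return (hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0)

def philox_4x32(counter, key, rounds=7):
    # The CUDA class runs 6 rounds in the loop plus one more when producing the output.
    ctr, k = list(counter), list(key)
    for _ in range(rounds - 1):
        ctr = list(_single_round(ctr, k))
        k = [(k[0] + PHILOX_10A) & MASK32, (k[1] + PHILOX_10B) & MASK32]
    return _single_round(ctr, k)

def uniform4(x4):
    # Same 2**-32 scaling as M_RAN_INVM32.
    return tuple(v * 2.0 ** -32 for v in x4)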
rotary.py
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
# We split the input differently ((d 2) -> d 2 instead of (2 d) -> d 2), following the original
# paper's implementation. This should not matter.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step
from typing import Tuple
import math

import torch
from einops import rearrange, repeat


def rotate_half(x):
    # rearrange doesn't work with torch.jit
    # x = rearrange(x, '... (d r) -> ... d r', r=2)
    x = x.unflatten(dim=-1, sizes=(-1, 2))
    x1, x2 = x.unbind(dim=-1)
    rotated_x = torch.stack((-x2, x1), dim=-1)
    # return rearrange(rotated_x, '... d r -> ... (d r)')
    return rotated_x.flatten(start_dim=-2)


@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int = -2):
    # NOTE: This could probably be moved to Triton

    # Handle a possible sequence length mismatch in between q and k
    cos = cos[:x.shape[seq_dimension], :]
    sin = sin[:x.shape[seq_dimension], :]
    if seq_dimension == -3:
        cos = cos[:, None, :]
        sin = sin[:, None, :]
    return (x * cos) + (rotate_half(x) * sin)


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
    """

    def __init__(self, dim_model: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=-2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (seq_len != self._seq_len_cached or self._cos_cached.device != x.device
                or self._cos_cached.dtype != x.dtype):
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            self._cos_cached = repeat(torch.cos(freqs).to(x.dtype), '... d -> ... (d 2)')
            self._sin_cached = repeat(torch.sin(freqs).to(x.dtype), '... d -> ... (d 2)')

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                seq_dimension=-2) -> Tuple[torch.Tensor, torch.Tensor]:
        assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dimension=seq_dimension
        )

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
        )
class
RotaryEmbedding2D
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
):
super
().
__init__
()
assert
dim
%
4
==
0
self
.
rotary_emb1d
=
RotaryEmbedding
(
dim
//
2
)
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
seq_dimension
=-
2
):
assert
seq_dimension
in
[
-
2
,
-
3
]
# Either (bs, h, s, d) or (bs, s, h, d)
seqlen
=
q
.
shape
[
seq_dimension
]
seqlen_sqrt
=
int
(
math
.
sqrt
(
seqlen
))
assert
seqlen
==
seqlen_sqrt
**
2
if
seq_dimension
==
-
3
:
# (bs, s, h, d)
q
=
rearrange
(
q
,
'b s h d -> b h s d'
)
k
=
rearrange
(
k
,
'b s h d -> b h s d'
)
q0
,
q1
=
q
.
chunk
(
2
,
dim
=-
1
)
k0
,
k1
=
k
.
chunk
(
2
,
dim
=-
1
)
# (bs, h, s, d)
q0
=
rearrange
(
q0
,
'b nheads (h w) d -> b nheads h w d'
,
h
=
seqlen_sqrt
)
k0
=
rearrange
(
k0
,
'b nheads (h w) d -> b nheads h w d'
,
h
=
seqlen_sqrt
)
q0_emb
,
k0_emb
=
self
.
rotary_emb1d
(
q0
,
k0
,
seq_dimension
=-
2
)
q0_emb
=
rearrange
(
q0_emb
,
'b nheads h w d -> b nheads (h w) d'
)
k0_emb
=
rearrange
(
k0_emb
,
'b nheads h w d -> b nheads (h w) d'
)
q1
=
rearrange
(
q1
,
'b nheads (h w) d -> b nheads h w d'
,
h
=
seqlen_sqrt
)
k1
=
rearrange
(
k1
,
'b nheads (h w) d -> b nheads h w d'
,
h
=
seqlen_sqrt
)
q1_emb
,
k1_emb
=
self
.
rotary_emb1d
(
q1
,
k1
,
seq_dimension
=-
3
)
q1_emb
=
rearrange
(
q1_emb
,
'b nheads h w d -> b nheads (h w) d'
)
k1_emb
=
rearrange
(
k1_emb
,
'b nheads h w d -> b nheads (h w) d'
)
q_emb
,
k_emb
=
torch
.
cat
([
q0_emb
,
q1_emb
],
dim
=-
1
),
torch
.
cat
([
k0_emb
,
k1_emb
],
dim
=-
1
)
if
seq_dimension
==
-
3
:
q_emb
=
rearrange
(
q_emb
,
'b h s d -> b s h d'
)
k_emb
=
rearrange
(
k_emb
,
'b h s d -> b s h d'
)
return
q_emb
,
k_emb
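A quick usage sketch of RotaryEmbedding (the batch size, head count, sequence length, and head dimension here are assumptions chosen only for illustration). Because both q and k at a given position are rotated by the same orthogonal matrix, their same-position dot products are preserved, which the check at the end confirms numerically:

import torch
from rotary import RotaryEmbedding

rotary = RotaryEmbedding(dim_model=64)
# (batch, nheads, seqlen, head_dim); seq_dimension=-2 matches this layout.
q = torch.randn(2, 16, 128, 64)
k = torch.randn(2, 16, 128, 64)
q_rot, k_rot = rotary(q, k, seq_dimension=-2)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
# Same-position dot products are unchanged by the rotation (up to float rounding).
torch.testing.assert_close((q_rot * k_rot).sum(-1), (q * k).sum(-1), atol=1e-4, rtol=1e-4)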
stream_attn_interface.py 0 → 100644
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import stream_attn_cuda


def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
    context, softmax_lse, *rest = stream_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s,
                                                       softmax_scale, False, causal,
                                                       return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _stream_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                          softmax_scale, causal):
    dqkv, dp, softmax_d = stream_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
                                               dropout_p, softmax_scale, max_s, False, causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


class StreamAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _stream_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None, temporarily use another tensor just to get it running
        dqkv = _stream_attn_backward(dout, qkv, context, context, softmax_lse, cu_seqlens,
                                     ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal)
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class StreamAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _stream_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _stream_attn_backward(dout, qkv, context, S_dmask, softmax_lse, cu_seqlens,
                                     ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal)
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None


def stream_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                     return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation"""
    func = StreamAttnFun if not return_attn_probs else StreamAttnFunWithS
    return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
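A minimal usage sketch of stream_attn_func for the fixed-length (no padding) case, mirroring how StreamingAttention calls it further below. The batch size, sequence length, and head shape are assumptions for illustration; the compiled stream_attn_cuda extension, a CUDA device, and fp16 inputs are required:

import torch
from stream_attn_interface import stream_attn_func

batch_size, seqlen, nheads, head_dim = 2, 128, 8, 64
# Packed qkv: (total_tokens, 3, nheads, head_dim), fp16, on the GPU.
qkv = torch.randn(batch_size * seqlen, 3, nheads, head_dim,
                  dtype=torch.float16, device='cuda', requires_grad=True)
# Cumulative sequence lengths [0, seqlen, 2*seqlen, ...] as int32.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')
out = stream_attn_func(qkv, cu_seqlens, dropout_p=0.1, max_s=seqlen, causal=True)
out.sum().backward()   # the dqkv gradient lands in qkv.grad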
stream_blocksparse_attn_interface.py 0 → 100644
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import stream_attn_cuda


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.
    0 means the block is skipped.
    nonzero means the block is not skipped.
    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit used to encode whether
            it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
            the last nonzero in its row.
    """
    assert not causal
    # TD [2022-05-13]: The indexing and sorting is very tricky
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
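To make the encoding described in the docstring concrete, here is a tiny worked example on a made-up 2x2 blockmask (column 0 has nonzero blocks in rows 0 and 1, column 1 only in row 1); the expected output comes from tracing the function by hand, so treat it as illustrative:

import torch
from stream_blocksparse_attn_interface import convert_blockmask

blockmask = torch.tensor([[1, 0],
                          [1, 1]])
print(convert_blockmask(blockmask, causal=False))
# tensor([[ 3,  5],
#         [ 6, -1]], dtype=torch.int32)
# Output row 0 (= mask column 0): 3 = 0*4 + 2 + 1 -> block row 0, both first and last nonzero of its row;
#                                 5 = 1*4 + 1     -> block row 1, first nonzero of its row.
# Output row 1 (= mask column 1): 6 = 1*4 + 2     -> block row 1, last nonzero of its row; padded with -1.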
def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
                                     causal, return_softmax):
    context, softmax_lse, *rest = stream_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
                                                             max_s, softmax_scale, causal,
                                                             return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
                                      dropout_p, max_s, softmax_scale, causal):
    dqkv, dp, softmax_d = stream_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse,
                                                     cu_seqlens, blockmask, dropout_p,
                                                     softmax_scale, max_s, causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


class StreamBlocksparseAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
            causal=causal, return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None, temporarily use another tensor just to get it running
        dqkv = _stream_blocksparse_attn_backward(dout, qkv, context, context, softmax_lse,
                                                 cu_seqlens, blockmask, ctx.dropout_p, ctx.max_s,
                                                 ctx.softmax_scale, ctx.causal)
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class StreamBlocksparseAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
            causal=causal, return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _stream_blocksparse_attn_backward(dout, qkv, context, S_dmask, softmax_lse,
                                                 cu_seqlens, blockmask, ctx.dropout_p, ctx.max_s,
                                                 ctx.softmax_scale, ctx.causal)
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
                                 causal=False, return_attn_probs=False, convert_mask=True):
    """dropout_p should be set to 0.0 during evaluation"""
    func = StreamBlocksparseAttnFun if not return_attn_probs else StreamBlocksparseAttnFunWithS
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
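A minimal usage sketch of stream_blocksparse_attn_func. The shapes are assumptions for illustration; judging from how StreamingBlocksparseAttention slices its layout below, the 0-1 blockmask covers a grid of seqlen/16 query blocks by seqlen/256 key blocks, and the compiled stream_attn_cuda extension, a CUDA device, and fp16 inputs are required:

import torch
from stream_blocksparse_attn_interface import stream_blocksparse_attn_func

batch_size, seqlen, nheads, head_dim = 2, 256, 8, 64
qkv = torch.randn(batch_size * seqlen, 3, nheads, head_dim,
                  dtype=torch.float16, device='cuda', requires_grad=True)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')
# 0-1 blockmask over (seqlen/16, seqlen/256) blocks; all-ones keeps every block (dense attention).
blockmask = torch.ones(seqlen // 16, seqlen // 256, device='cuda')
out = stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, 0.0, seqlen)
out.sum().backward()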
streaming_attention.py 0 → 100644
import math

import torch
import torch.nn as nn
from einops import rearrange

from rotary import RotaryEmbedding, RotaryEmbedding2D
from stream_attn_interface import stream_attn_func
from bert_padding import unpad_input, pad_input, index_first_axis


class StreamingAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                              many queries each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                          dtype=torch.int32, device=qkv.device)
                output = stream_attn_func(qkv, cu_seqlens,
                                          self.dropout_p if self.training else 0.0, max_s,
                                          softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = stream_attn_func(x_unpad, cu_seqlens,
                                                self.dropout_p if self.training else 0.0, max_s,
                                                softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = stream_attn_func(qkv, cu_seqlens,
                                      self.dropout_p if self.training else 0.0, max_s,
                                      softmax_scale=self.softmax_temp, causal=causal)

        return output, None


class StreamingMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        assert use_rotary_emb in [None, '1d', '2d']
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = StreamingAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
        qkv = self.Wqkv(x)
        if self.use_rotary_emb:
            query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d',
                                          three=3, h=self.num_heads).unbind(dim=2)
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            qkv = torch.stack([query, key, value], dim=2)
        else:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
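A short usage sketch of StreamingMHA. The hyperparameters below are assumptions for illustration; fp16 inputs and a CUDA device are required by the asserts in StreamingAttention, and the second and third positional arguments exist only to mimic nn.MultiheadAttention's (query, key, value) call signature and are ignored:

import torch
from streaming_attention import StreamingMHA

mha = StreamingMHA(embed_dim=512, num_heads=8, attention_dropout=0.1, causal=True,
                   use_rotary_emb='1d', device='cuda', dtype=torch.float16)
x = torch.randn(2, 128, 512, device='cuda', dtype=torch.float16)
out, attn_weights = mha(x, x, x)   # the 2nd and 3rd arguments are ignored
print(out.shape)                   # torch.Size([2, 128, 512]); attn_weights is None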
streaming_blocksparse_attention.py 0 → 100644
import math

import torch
import torch.nn as nn
from einops import rearrange
import hydra

from stream_blocksparse_attn_interface import stream_blocksparse_attn_func
from stream_blocksparse_attn_interface import convert_blockmask
from bert_padding import unpad_input, pad_input, index_first_axis


class StreamingBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
                 max_seq_length=2048, device=None, dtype=None):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # initialize sparse layout and register as buffer
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False, convert_mask=True):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                              many queries each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                          dtype=torch.int32, device=qkv.device)
                output = stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask,
                                                      self.dropout_p if self.training else 0.0,
                                                      max_s, softmax_scale=self.softmax_temp,
                                                      causal=causal)
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = stream_blocksparse_attn_func(x_unpad, cu_seqlens, blockmask,
                                                            self.dropout_p if self.training else 0.0,
                                                            max_s, softmax_scale=self.softmax_temp,
                                                            causal=causal)
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            seqlen = max_s
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
            if convert_mask:
                output = stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask,
                                                      self.dropout_p if self.training else 0.0,
                                                      max_s, softmax_scale=self.softmax_temp,
                                                      causal=causal)
            else:
                output = stream_blocksparse_attn_func(qkv, cu_seqlens, self.blockmask_converted,
                                                      self.dropout_p if self.training else 0.0,
                                                      max_s, softmax_scale=self.softmax_temp,
                                                      causal=causal, convert_mask=False)

        return output, None


class StreamingBlocksparseMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
                 attention_dropout=0.0, causal=False, max_seq_length=2048,
                 device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = StreamingBlocksparseAttention(sparsity_config,
                                                        attention_dropout=attention_dropout,
                                                        max_seq_length=max_seq_length,
                                                        **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
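The sparsity_config argument above is whatever hydra.utils.instantiate resolves it to; the only contract visible in this file is a make_layout(max_seq_length) method returning a 0-1 layout whose rows index 16-token query blocks and whose columns index 256-token key blocks (that is what the [:seqlen_rounded // 16, :seqlen_rounded // 256] slicing implies). A hedged, standalone sketch of such a layout, independent of hydra and of any particular sparsity-config class (the helper name and the local-window rule are made up for illustration):

import torch

def local_window_layout(max_seq_length: int, window: int = 1) -> torch.Tensor:
    """Toy 0-1 layout: keep key block j for query block i when j is within +/- `window`
    of the key block containing that query block. Query blocks are 16 tokens and key
    blocks 256 tokens, matching how StreamingBlocksparseAttention slices self.layout."""
    n_q_blocks = max_seq_length // 16
    n_k_blocks = max_seq_length // 256
    # Key-block index that contains each 16-token query block.
    q_block_pos = torch.arange(n_q_blocks) * 16 // 256
    k_block_pos = torch.arange(n_k_blocks)
    layout = (q_block_pos[:, None] - k_block_pos[None, :]).abs() <= window
    return layout.to(torch.uint8)

print(local_window_layout(1024).shape)  # torch.Size([64, 4])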